In [1]:
import os
import os.path as osp
from os import environ
import time

import torch
from sklearn.metrics import roc_auc_score

import torch_geometric.transforms as T
from torch_geometric.datasets import CitationFull, Coauthor, Twitch
from gcn_conv import GCNConv
from torch_geometric.utils import negative_sampling, to_dense_adj, add_remaining_self_loops, degree, is_undirected

import seaborn
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.express as px
import math
import numpy as np

from scipy.stats import pearsonr

In [2]:
environ['CUDA_LAUNCH_BLOCKING'] = "1"
dataset_name = environ.get('dataset_name', 'Cora_ML')
dataset_name = 'Physics'
conv_type = environ.get('conv_type', 'sym')
k = 2

In [3]:
class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, conv_type):
        super().__init__()
        self.convs = torch.nn.ModuleList([GCNConv(in_channels, hidden_channels, conv=conv_type)] + \
                                [GCNConv(hidden_channels, hidden_channels, conv=conv_type) for i in range(k - 2)] + \
                                [GCNConv(hidden_channels, out_channels, conv=conv_type)])

    def encode(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = conv(x, edge_index).relu()
        return self.convs[-1](x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return prob_adj

In [4]:
def train(model, train_data):
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)

    # We perform a new round of negative sampling for every training epoch:
    neg_edge_index = negative_sampling(
        edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
        num_neg_samples=train_data.edge_label_index.size(1), method='sparse')

    edge_label_index = torch.cat(
        [train_data.edge_label_index, neg_edge_index],
        dim=-1,
    )
    edge_label = torch.cat([
        train_data.edge_label,
        train_data.edge_label.new_zeros(neg_edge_index.size(1))
    ], dim=0)

    out = model.decode(z, edge_label_index).view(-1)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
        
    return loss

@torch.no_grad()
def test(model, data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())

In [5]:
def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [6]:
def load_dataset(dataset_name):
    if dataset_name in ["Cora", "Cora_ML", "CiteSeer", "PubMed"]:
        path = osp.join('.', 'data', 'CitationFull')
        dataset = CitationFull(path, name=dataset_name, transform=transform)
    elif dataset_name in ["CS", "Physics"]:
        path = osp.join('.', 'data', 'Coauthor')
        dataset = Coauthor(path, name=dataset_name, transform=transform)
    elif dataset_name in ["RU"]:
        path = osp.join('.', 'data', 'Twitch')
        dataset = Twitch(path, name="RU", transform=transform)
    else:
        raise ValueError

    return dataset

In [7]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                      add_negative_train_samples=False),
])

dataset = load_dataset(dataset_name)
test_data = dataset[0][2]
num_labels = int(test_data.y.max()) + 1

# masks = []
x_list = [[[] for _ in range(10)] for _ in range(num_labels)]
y_list = [[[] for _ in range(10)] for _ in range(num_labels)]
pcc_list = [[[] for _ in range(10)] for _ in range(num_labels)]
auc_list = []

for seed_idx, seed in enumerate([34, 87, 120, 11, 93, 24, 25, 56, 49, 54]):
    seed_everything(0)
    
    dataset = load_dataset(dataset_name)
    train_data, val_data, test_data = dataset[0]
    
    seed_everything(seed)
    
    model = Net(dataset.num_features, 128, 64, conv_type).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss()

    best_val_auc = final_test_auc = 0
    for epoch in range(1, 101):
        loss = train(model, train_data)
        val_auc = test(model, val_data)
        test_auc = test(model, test_data)
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            final_test_auc = test_auc
#         if epoch % 20 == 0:
#             print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
#                   f'Test: {test_auc:.4f}')

    print(f'Final Test: {final_test_auc:.4f}')
    auc_list.append(final_test_auc)
    
    model.eval()
    rep = model.encode(test_data.x, test_data.edge_index)
    final_probs = model.decode_all(rep).detach().cpu()
    
    label_adj = to_dense_adj(test_data.edge_label_index, max_num_nodes=test_data.x.size(0))[0]
    intra_mask = (test_data.y.reshape(-1, 1) == test_data.y.reshape(1, -1)) & (label_adj == 1)
#     masks.append(intra_mask)
    
#     U, _, V = torch.pca_lowrank(rep, q=2)
#     rep_pca = U.detach().cpu().numpy()
#     groups = test_data.y.cpu().numpy()
#     for c in range(int(groups.max()) + 1):
#         plt.scatter(rep_pca[:, 0][groups == c], rep_pca[:, 1][groups == c])
#     plt.show()
    
    z = test_data.x
    for conv in model.convs:
        z = z @ conv.lin.weight.detach().t()

    taylor_rep = torch.zeros_like(z)
    adj = to_dense_adj(test_data.edge_index, max_num_nodes=test_data.x.size(0))[0]
    for c in range(int(test_data.y.max()) + 1):
        c_mask = torch.nonzero(test_data.y.flatten() == c).flatten()
        c_section = adj[c_mask.reshape(-1,1), c_mask]

        deg_i = c_section.sum(dim=1)
        deg_j = c_section.sum(dim=0)

        if conv_type == "sym":
            eigv_i = torch.sqrt(deg_i / deg_i.sum())
            eigv_j = torch.sqrt(deg_j / deg_j.sum())
        else:
            eigv_i = torch.ones_like(deg_i)
            eigv_j = deg_j / deg_j.sum()

        taylor_rep[c_mask] = eigv_i.reshape(-1, 1) @ eigv_j.reshape(1, -1) @ z[c_mask]
    
    pred_probs = model.decode_all(taylor_rep).detach().cpu()
#     y_idx = test_data.y.argsort()
#     plot_probs = pred_probs[y_idx][:, y_idx]
#     plot_mask = intra_mask[y_idx][:, y_idx]
#     plot_probs[~plot_mask] = 0
#     seaborn.heatmap(plot_probs.cpu().numpy())
#     plt.show()

    for c in range(int(test_data.y.max()) + 1):
        c_mask = ((test_data.y.reshape(-1, 1) == c) & (test_data.y.reshape(1, -1) == c)) & (label_adj == 1)
        
        x = final_probs[c_mask].flatten().numpy()
        y = pred_probs[c_mask].flatten().numpy()
        
        if len(x) <= 1:
            pcc = (0.0, 0.0)
        else:
            pcc = pearsonr(x, y)
        pcc_list[c][seed_idx].append(pcc[0])

#         plt.scatter(x, y)
#         plt.xlabel('Actual link prediction score')
#         plt.ylabel('Estimated link prediction score')
#         plt.title(dataset.name + ", " + r"$r = {:.3f}$, ".format(pcc[0]) + r"$p = {:.3f}$".format(pcc[1]))
#         plt.show()

        x_list[c][seed_idx] = x.tolist()
        y_list[c][seed_idx] = y.tolist()

Final Test: 0.9277


RuntimeError: CUDA out of memory. Tried to allocate 4.43 GiB (GPU 0; 11.91 GiB total capacity; 6.73 GiB already allocated; 3.29 GiB free; 7.79 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# t0 = masks[0]
# for t1 in masks[1:]:
#     assert (t0 == t1).all()

In [None]:
pcc_all = torch.tensor(pcc_list)
pcc_avg = pcc_all.mean().item()
pcc_std = pcc_all.std().item()
pcc_str = "{0:.3f} ± ".format(pcc_avg) + "{0:.3f}".format(pcc_std)
print(pcc_str)

In [None]:
auc_all = torch.tensor(auc_list)
auc_avg = auc_all.mean().item()
auc_std = auc_all.std().item()
auc_str = "{0:.3f} ± ".format(auc_avg) + "{0:.3f}".format(auc_std)
print(auc_str)

In [None]:
axis_titles = ["Actual link prediction score", "Estimated link prediction score"]

for idx, (a_list, b_list) in enumerate([(x_list, y_list), (y_list, x_list)]):
    
    plotly_objs = []
    
    min_avg = 0
    max_avg = 0
    
    for c in range(len(a_list)):
        
        x_all = torch.tensor(a_list[c])
        y_all = torch.tensor(b_list[c])
        
        if x_all.numel() == 0:
            continue

        for i in range(len(x_all)):
            x, y = x_all[i], y_all[i]
            rho, _, _, _ = np.linalg.lstsq(x.flatten().numpy()[:,np.newaxis], y.flatten().numpy())
            y_all[i] /= rho[0]

        x_avg = x_all.mean(dim=0)
        y_avg = y_all.mean(dim=0)
        
        min_avg = min(min_avg, x_avg.min().item())
        min_avg = min(min_avg, y_avg.min().item())
        max_avg = max(max_avg, x_avg.max().item())
        max_avg = max(max_avg, y_avg.max().item())

        x_sort = torch.argsort(x_avg)
        x_avg = x_avg[x_sort].tolist()

        y_max = y_all.max(dim=0).values[x_sort].tolist()
        y_min = y_all.min(dim=0).values[x_sort].tolist()
        y_avg = y_avg[x_sort].tolist()
        
        plotly_objs.extend([
            go.Scatter(
                x=x_avg,
                y=y_avg,
                mode='markers',
                showlegend=False
            ),
            go.Scatter(
                x=x_avg+x_avg[::-1], # x, then x reversed
                y=y_max+y_min[::-1], # upper, then lower reversed
                fill='toself',
                fillcolor='rgba(0,100,80,0.2)',
                line=dict(color='rgba(255,255,255,0)'),
                hoverinfo="skip",
                showlegend=False
            )
        ])

    diag_y = [min_avg, max_avg]
    diag_x = diag_y
    
    plotly_objs.append(go.Scatter(
                x=diag_x,
                y=diag_y,
                mode='lines',
                line={'dash': 'dash', 'color': 'black'},
                showlegend=False
            ))
  
    # avoid loading mathjax text
    fig=px.scatter(x=[0, 1, 2, 3, 4], y=[0, 1, 4, 9, 16])
    fig.write_image("plots/{}_{}_{}.pdf".format(dataset.name, idx, conv_type))
    time.sleep(2)

    if conv_type == "sym":
        plot_title = r"$\Phi_s \text{ on " + dataset.name + r", test AUC: }" + auc_str + r"$"
    else:
        plot_title = r"$\Phi_r \text{ on " + dataset.name + r"}$"

    fig = go.Figure(plotly_objs).update_layout(
        xaxis_title=r"$\text{Average " + axis_titles[idx].lower() + r"}$",
        yaxis_title=r"$\text{" + axis_titles[1 - idx] + r"}$"
    )
    fig.update_layout(title_text=plot_title, title_x=0.5)
    fig.update_layout(
        margin=dict(l=20, r=20, t=30, b=20),
    )
    fig.add_annotation(dict(x=0.75,
                            y=0.2,
                            text=r"$r = {}$".format(pcc_str),
                            showarrow=False,
                            textangle=0,
                            xanchor='left',
                            xref="paper",
                            yref="paper"))
    fig.write_image("plots/{}_{}_{}.pdf".format(dataset.name, idx, conv_type))