In [20]:
import os

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
import torch_geometric as pyg

import sklearn.metrics as mt

from tqdm.notebook import tqdm

import clip_graph as cg

import utils as ut

In [21]:
os.chdir(os.path.expanduser('~/lizaixi/congrat-copy'))

In [22]:
device = 'cpu'

In [23]:
pl.seed_everything(2969591811)

Seed set to 2969591811


2969591811

# What should we evaluate?

In [None]:
datasets = {
    'pubmed': {
        'svd_init_dataset': 'configs/eval-datasets/pubmed/gassocausal.yaml',
        'svd_init_baseline': 'lightning_logs/gnn-pretrain/pubmed/version_2/',
        'svd_init_key': 'x',
        
        'models': {
            'causal': {
                'base': 'lightning_logs/clip-graph/inductive-causal/pubmed/version_102/',
                # 'sim10': 'lightning_logs/clip-graph/inductive-causal/pubmed/version_21/',
            },
        },
    },
}

# Do the evaluation

In [25]:
def test(z, pos_edge_index, neg_edge_index=None, eps=1e-15):
    if neg_edge_index is None:
        neg_edge_index = pyg.utils.negative_sampling(pos_edge_index, z.size(0))
    
    pos_y = z.new_ones(pos_edge_index.size(1))
    neg_y = z.new_zeros(neg_edge_index.size(1))
    y = torch.cat([pos_y, neg_y], dim=0).long()
    y = y.detach().cpu().numpy()

    decoder = pyg.nn.models.autoencoder.InnerProductDecoder()
    pos_dec = decoder(z, pos_edge_index, sigmoid=True)
    neg_dec = decoder(z, neg_edge_index, sigmoid=True)
    pred = torch.cat([pos_dec, neg_dec], dim=0)
    pred = pred.detach().cpu().numpy()

    return {
        'auc': mt.roc_auc_score(y, pred),
        'ap': mt.average_precision_score(y, pred),
        
        'recon': (
            -torch.log(pos_dec + eps).mean() +
            -torch.log(1 - neg_dec + eps).mean()
        ).item(),
        
        # very good scores from our model, but poorly calibrated;
        # let's just report the AUC/AP
        # 'accuracy': mt.accuracy_score(y, pred > 0.5),
        # 'precision': mt.precision_score(y, pred > 0.5),
        # 'recall': mt.recall_score(y, pred > 0.5),
        # 'f1': mt.f1_score(y, pred > 0.5),
    }

In [31]:
results = []

for dataset, paths in tqdm(datasets.items()):
    #
    # Dataset and specific objects to input to models
    #
    
    #clip_graph.data.datamodule.PubmedGraphTextDataModule
    dm = cg.utils.datamodule_from_yaml(paths['svd_init_dataset'])['dm']

    tx = getattr(dm.train_dataset.dataset.graph_data, paths['svd_init_key']).to(device)
    tw = dm.train_dataset.dataset.graph_data.edge_attr.to(device)
    tei = dm.train_dataset.dataset.graph_data.edge_index.to(device)
    tnei = dm.train_dataset.dataset.graph_data.neg_edge_index.to(device)
    print(tx, tei, tnei)
    print('y'*100)
    vx = getattr(dm.test_dataset.dataset.graph_data, paths['svd_init_key']).to(device)
    vw = dm.test_dataset.dataset.graph_data.edge_attr.to(device)
    vei = dm.test_dataset.dataset.graph_data.edge_index.to(device)
    vnei = dm.test_dataset.dataset.graph_data.neg_edge_index.to(device)

    #
    # Baselines
    #
    
    ## Fine-tuned for graph autoencoding
    gn_model = cg.scoring.interpret_ckpt_dir(paths['svd_init_baseline'], dm)['model'].model.encoder
    gn_model = gn_model.to(device)
    
    ## Same architecture, randomly initialized, totally untrained
    ckpt = cg.scoring.interpret_ckpt_dir(paths['svd_init_baseline'], dm)
    cls = getattr(cg.models, ckpt['config']['model']['init_args']['model_class_name'])
    params = ckpt['config']['model']['init_args']['model_params']
    bl_model = cls(**params)
    bl_model = bl_model.to(device)

    #
    # Generate embeddings
    #

    embs = {}

    ## First, baselines
    with torch.no_grad():
        embs[f'{dataset}-baseline'] = gn_model(vx, vei)['output']
        embs[f'{dataset}-untrained'] = bl_model(vx, vei)['output']
        
    ## Other models
    for lmtype in tqdm(paths['models'].keys()):
        for mod, path in tqdm(paths['models'][lmtype].items()):
            # print("--_____________________")
            # print(path, dm)
            #lightning_logs/clip-graph/inductive-causal/pubmed/version_21/ 
            # {Train dataloader: size=511}
            # {Validation dataloader: size=32}
            # {Test dataloader: size=84}
            # {Predict dataloader: None}
            cg_model = cg.scoring.interpret_ckpt_dir(path, dm)['model'].model
            cg_model = cg_model.to(device)
            cg_model.embed_nodes(vx, vei, vw)
            embs[f'{dataset}-{lmtype}_{mod}'] = F.normalize(cg_model.embed_nodes(vx, vei, vw), p=2, dim=1)
    
    res = pd.Series({
        k : test(v, vei, vnei)
        for k, v in tqdm(embs.items())
    }).apply(pd.Series)
    
    res['dataset'] = res.index.str.split('-').map(lambda s: s[0])
    res['model'] = res.index.str.split('-').map(lambda s: s[1])
    res = res.reset_index(drop=True).set_index(['dataset', 'model'])
    
    results += [res]

results = pd.concat(results, axis=0)
results = results.sort_index()

results.to_csv('data/link-prediction-eval.csv', index=True)

  0%|          | 0/1 [00:00<?, ?it/s]

batch_size 72
in_get_graph_object
in_get_graph_object_pickle_load


  return torch.load(io.BytesIO(b))


in_graph_text_dataset_init_super
in_graph_dataset_mixin_init
in_text_dataset_init
in_text_dataset_init_super
in_text_dataset_init_tokenizer
in_text_dataset_init_tokenizer_pad_token
in_text_dataset_init_tokenizer_params
in_graph_dataset_mixin_init_super
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
torch.Size([2, 122178])
torch.Size([122178, 1])
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
in_graph_dataset_mixin_init_graph_data
new_edge_attr torch.Size([122178, 1])
pre_edge_attr torch.Size([122178, 1])
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
in_graph_dataset_mixin_init_drop_isolates
in_graph_text_dataset_init
in_graph_text_dataset_init_unique_text_node_ids
in_setup_graph_text_dataset_extract_subgraph_with_texts
node_mask tensor([True, True, True,  ..., True, True, True]) --------------------------------------------------
node_idx tensor([    0,     1,     2,  ..., 19713, 19714, 19715]) --------------------------

Seed set to 2969591811


in_graph_dataset_mixin_compute_mutuals
in_graph_dataset_mixin_compute_mutuals_A_to_device
torch.cuda.is_available(): True
torch.Size([713, 713]) --------------------------------------------------
in_graph_dataset_mixin_compute_mutuals_with_torch_no_grad
in_graph_dataset_mixin_compute_mutuals_F_normalize
in_graph_dataset_mixin_compute_mutuals_mutual_T
in_graph_dataset_mixin_compute_mutuals_mutual_cpu
in_setup_graph_text_dataset_compute_mutuals
in_graph_dataset_mixin_compute_mutuals
in_graph_dataset_mixin_compute_mutuals_A_to_device
torch.cuda.is_available(): True
torch.Size([1996, 1996]) --------------------------------------------------
in_graph_dataset_mixin_compute_mutuals_with_torch_no_grad
in_graph_dataset_mixin_compute_mutuals_F_normalize
in_graph_dataset_mixin_compute_mutuals_mutual_T
in_graph_dataset_mixin_compute_mutuals_mutual_cpu
in_setup_graph_text_dataset_compute_mutuals
tensor([[-0.2111,  0.8552, -0.3588,  ...,  0.1306, -1.5901, -0.3376],
        [-0.2444, -0.1007, -0.0971

Seed set to 2969591811


++++++++++++++++++++ lightning_logs/gnn-pretrain/pubmed/version_2/checkpoints/epoch=19-step=20.ckpt <class 'clip_graph.lit.LitGAE'>


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

Seed set to 2969591811


++++++++++++++++++++ lightning_logs/clip-graph/inductive-causal/pubmed/version_102/checkpoints/epoch=3-step=36748.ckpt <class 'clip_graph.gassolit.LitClipGraph'>


Seed set to 2969591811


++++++++++++++++++++ lightning_logs/clip-graph/inductive-causal/pubmed/version_21/checkpoints/epoch=7-step=73496.ckpt <class 'clip_graph.lit.LitClipGraph'>


TypeError: ClipGraph.embed_nodes() takes 3 positional arguments but 4 were given

# Examine results

In [8]:
results = pd.read_csv('data/link-prediction-eval.csv')
results = results.set_index(['dataset', 'model'])

## Display

In [9]:
with pd.option_context('display.max_rows', None):
    display(results)

Unnamed: 0_level_0,Unnamed: 1_level_0,auc,ap,recon
dataset,model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
pubmed,baseline,0.759655,0.800513,1.767621
pubmed,causal_base,0.962295,0.953642,1.271247
pubmed,causal_sim10,0.962773,0.954242,1.274712
pubmed,untrained,0.536332,0.520224,29.225744


## Tables for paper

In [15]:
mods = [
    'baseline', 'untrained',
    'causal_base', 
]

tmp = results.loc[pd.IndexSlice[:, mods], :].sort_index()
tmp = tmp['auc'].reset_index().copy()

model_map = {
    **{
        k : k
        for k in tmp['model'].unique()
        if k not in ('baseline', 'untrained')
    },
    
    **{
        'baseline': 'baseline_svd',
        'untrained': 'baseline_untrained',
    }
}

tmp['model'] = tmp['model'].map(model_map)
tmp['type'] = tmp['model'].apply(lambda s: s.split('_')[0])
tmp['model'] = tmp['model'].apply(lambda s: s.split('_')[1])

tmp = tmp.loc[tmp['model'] != 'untrained', :]

tmp = tmp.set_index(['dataset', 'type', 'model'])
tmp = tmp.sort_index()
tmp = tmp.unstack(0)
tmp.columns = tmp.columns.droplevel(0)
tmp = tmp.loc[['causal', 'baseline'], :]

tmp.index = tmp.index.set_levels(tmp.index.levels[0].map({
    'causal': 'Causal',
    'baseline': 'GNN Autoencoder',
}), level=0)

tmp.index = tmp.index.set_levels(tmp.index.levels[1].map({
    'base': r'$\alpha = 0$',
    'sim10': r'$\alpha = 0.1$',
    'svd': 'SVD',
    'untrained': 'Untrained GNN',
}), level=1)

tmp.index.names = ['', '']
tmp.index = tmp.index.swaplevel()
tmp = tmp.sort_index()

tmp = tmp[['pubmed']]

tmp.columns = tmp.columns.map({
    'pubmed': 'Undirected-Pubmed',
    'trex': 'Undirected-TRex',
    'twitter_small': 'Undirected-Twitter',

    'pubmed_directed': 'Directed-Pubmed',
    'trex_directed': 'Directed-TRex',
    'twitter_small_directed': 'Directed-Twitter'
})

tmp.columns = pd.MultiIndex.from_frame(pd.DataFrame(tmp.columns.to_series().reset_index(drop=True).str.split('-').tolist()))

tmp.columns.name = ''
tmp.columns.names = ['', '']

tmp.index.names = ['alpha', 'txt']
tmp = tmp.reset_index()
tmp['txt'] = 'ConGraT-' + tmp['txt']
tmp['txt'] = tmp['txt'] + ' (' + tmp['alpha'] + ')'

tmp.loc[tmp['txt'] == 'ConGraT-GNN Autoencoder (Baseline)', 'txt'] = 'GAT Autoencoder (Baseline)'

tmp = tmp.drop('alpha', axis=1)

# tmp = tmp.apply(np.roll, shift=1)

tmp = tmp.set_index('txt')
tmp.index.name = ''

  tmp = tmp.drop('alpha', axis=1)


In [16]:
def bold_except_last_row(s):
    return pd.concat([
        ut.bold_above_thresh(s[:-1], s[-2]),
        pd.Series([''], index=[s.index[-1]]),
    ])

tab = tmp.style \
    .format(precision=3, na_rep='--') \
    .apply(bold_except_last_row, axis=0)
    
with pd.option_context('display.html.use_mathjax', True):
    display(tab)

  ut.bold_above_thresh(s[:-1], s[-2]),


Unnamed: 0_level_0,Undirected
Unnamed: 0_level_1,Pubmed
,
ConGraT-Causal ($\alpha = 0$),0.962
ConGraT-GNN Autoencoder (SVD),0.76


In [17]:
print(tab.to_latex(
        hrules = True,
        column_format = 'lcccccc',
        position = 'ht',
        label = 'tab:link-prediction',
        multicol_align = '|c',
        position_float = 'centering',
        environment = 'table*',
        convert_css = True,
    ))

\begin{table*}[ht]
\centering
\label{tab:link-prediction}
\begin{tabular}{lcccccc}
\toprule
 & Undirected \\
 & Pubmed \\
 &  \\
\midrule
ConGraT-Causal ($\alpha = 0$) & 0.962 \\
ConGraT-GNN Autoencoder (SVD) & 0.760 \\
\bottomrule
\end{tabular}
\end{table*}



  ut.bold_above_thresh(s[:-1], s[-2]),
