# Comparing graph-connection methodologies with the trained model's learned attention weights

### Imports

In [1]:
import os, sys
import pickle
from typing import List, Tuple
from collections.abc import Callable
import time
import datetime as dt
from tqdm.notebook import tqdm

In [2]:
import pandas as pd
import numpy as np
import networkx as nx

In [3]:
from scipy.stats import wasserstein_distance

In [4]:
from sklearn.preprocessing import OneHotEncoder, StandardScaler, MultiLabelBinarizer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer, make_column_selector
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

In [5]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, DenseDataLoader

from torch_geometric.nn import GraphConv, global_add_pool, DenseGraphConv, dense_diff_pool
import torch.nn.functional as F
from torch.nn import NLLLoss

from torch_geometric.utils import to_dense_adj, to_networkx
from torch_geometric.transforms import ToDense

In [6]:
import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['figure.figsize'] = 15, 8.27

import seaborn as sns
import plotly.express as px
import plotly.io as pio
pio.templates.default = 'seaborn'

In [7]:
from ipywidgets import interact, interact_manual, FloatSlider

In [8]:
# Put the parent directory and ../src/ (both resolved relative to the current working
# directory) on the import path so the `src.*` project modules below can be imported.
# Assumes the notebook is run from a directory one level below the repository root.
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src/'))

from src.utils import load_dataset, ASSETS_DIR, CHECKPOINTS_DIR
from src.models import DiffPool, BaselineGNN
from src.train import train
from src.metrics import evaluate, TrainingMetrics, TestingMetrics

# NOTE(review): hard-coded, user-specific absolute path that no cell shown here reads —
# confirm it is still needed, or make it configurable (e.g. via an environment variable).
CONNECTION_DIR = '/Users/arnauddhaene/Downloads/'

### Fetching the necessary data

### Fetching the trained model

In [15]:
# Architecture hyper-parameters; they must match the ones used when the checkpoint was
# saved, otherwise load_state_dict reports missing/unexpected keys.
# NOTE(review): the interact output further down shows graphs with 44 node features
# (x=[60, 44]) while node_features_dim is 43 — confirm this matches the dataset in use.
model_args = dict(num_classes=2, hidden_dim=64, node_features_dim=43)

model = BaselineGNN(layer_type='GAT', **model_args)

# Join path components rather than concatenating strings, so a missing trailing slash
# on ASSETS_DIR cannot silently produce a wrong path.
storage_path = os.path.join(ASSETS_DIR, 'models',
                            'Baseline GNN with 5 GAT layers-2021-10-30 13:27:34.879462.pkl')

# The model is only used for analysis here: switch off training-mode behaviour (e.g. dropout).
# Done before loading so the load result stays the cell's displayed output.
model.eval()
# map_location='cpu' lets a checkpoint saved on a GPU machine load on a CPU-only one.
model.load_state_dict(torch.load(storage_path, map_location='cpu'))

<All keys matched successfully>

### Fetching the data used for training

In [16]:
# Load the pickled (train, test) split; the context manager closes the file handle,
# which the bare open() call previously leaked.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted checkpoint files.
with open(os.path.join(CHECKPOINTS_DIR, 'wasserstein_20_27_False_2021-11-01.pt'), 'rb') as datafile:
    dataset_train, dataset_test = pickle.load(datafile)

FileNotFoundError: [Errno 2] No such file or directory: '/Users/arnauddhaene/development/lts4/graphmel/src/../data/checkpoints/wasserstein_20_27_False_2021-11-01.pt'

In [12]:
from collections import Counter

# Label distribution (label -> number of graphs) of the train and test splits.
Counter(g.y.item() for g in dataset_train), Counter(g.y.item() for g in dataset_test)

(Counter({0: 45, 1: 26}), Counter({0: 9, 1: 9}))

In [13]:
def get_attention_weights(model: torch.nn.Module, graph: "Data", *,
                          activation: "Callable[[torch.Tensor], torch.Tensor] | None" = None,
                          ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the attention weights of the last convolution in ``model`` on ``graph``.

    ``graph.x`` is passed through every layer of ``model.convs`` except the last; the
    last layer is then called with ``return_attention_weights=True``.

    Args:
        model: module exposing an indexable ``convs`` container whose last layer
            accepts ``return_attention_weights=True`` (e.g. a PyG ``GATConv``).
        graph: object with ``x`` (node features) and ``edge_index`` attributes.
        activation: optional function applied after each intermediate layer.
            NOTE(review): the loop does not replay any non-linearity that the
            model's own forward pass may apply between layers; pass it here
            (e.g. ``F.relu``) if BaselineGNN uses one — confirm.

    Returns:
        ``(edge_index, alpha)`` exactly as returned by the last layer: the edge index
        it attended over (may differ from ``graph.edge_index``, e.g. added self-loops)
        and the attention coefficients for those edges.
    """
    x, edge_index = graph.x, graph.edge_index

    # Evaluate in inference mode (training-mode dropout on the attention coefficients
    # would otherwise randomly perturb the returned weights) and without autograd,
    # then restore whatever mode the caller had the model in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for conv in model.convs[:-1]:
                x = conv(x, edge_index)
                if activation is not None:
                    x = activation(x)
            _, (edge_index, alpha) = model.convs[-1](
                x, edge_index, return_attention_weights=True)
    finally:
        model.train(was_training)

    return edge_index, alpha

In [14]:
# NOTE(review): show_attention's own `example` parameter shadows this module-level name,
# so it is not used by the function below; kept in case other cells rely on it.
example = dataset_train[41]
@interact(example=dataset_train)
def show_attention(example: Data):
    """Scatter the last layer's attention weights against a per-edge feature gap.

    For every edge ``i -> j`` returned by ``get_attention_weights`` the plot shows the
    attention coefficient (x axis) against ``|x[i, 1] - x[j, 1]|`` (y axis, labelled
    ``vol_ccm_d``).

    Args:
        example: graph chosen from ``dataset_train`` through the ipywidgets dropdown.
    """
    # Node-feature column compared across each edge; plotted under the name `vol_ccm_d`.
    # NOTE(review): presumably a volume-type feature — confirm against the dataset schema.
    feature_col = 1

    edge_index, alpha = get_attention_weights(model, example)
    edge_index = edge_index.detach()
    alpha = alpha.detach()

    # With several attention heads alpha is (num_edges, heads); flattening it would
    # interleave the heads and pair weights with the wrong edges, so average over heads.
    # For a single head this is identical to flattening.
    if alpha.dim() > 1:
        alpha = alpha.mean(dim=-1)

    src, dst = edge_index
    # Absolute feature difference for all edges at once instead of one .item() per edge.
    vol = (example.x[src, feature_col] - example.x[dst, feature_col]).abs().detach()

    df = pd.DataFrame(dict(
        edge=[f'{i} -> {j}' for i, j in zip(src.tolist(), dst.tolist())],
        alpha=alpha.tolist(),
        vol_ccm_d=vol.tolist(),
    ))

    sns.scatterplot(data=df, x='alpha', y='vol_ccm_d')

    plt.show()

interactive(children=(Dropdown(description='example', options=(Data(x=[60, 44], edge_index=[2, 396], y=[1], nu…