In [None]:
import torch
!pip install torch-geometric torch-scatter torch-sparse -f https: // data.pyg.org/whl/torch-{torch.__version__}.html

In [None]:
!rm gnn-dissect -r
!git clone https://github.com/ManthanDalmia/gcNeuron.git

## Setup

In [None]:
# Work from the checkout's src/ directory. Relative paths used later in this notebook
# (e.g. ../concepts/ in save_concepts) are resolved from here.
%cd gnn-dissect/src/

In [None]:
# Extra third-party packages, installed quietly (-q). Versions are not pinned here.
!pip -q install shap pyvis rdkit karateclub torch_explain

In [None]:
!pip install matplotlib == 3.5.1
!pip install networkx == 2.6.3

In [None]:
# dill is used below by save_concepts / load_concepts to (de)serialise concept dictionaries.
!pip install dill

In [None]:
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import k_hop_subgraph, to_networkx

In [None]:
import copy
import dill
import importlib
import json
import math
import matplotlib.pyplot as plt
import node_explainer
import numpy as np
import pandas as pd
import seaborn as sns
import shap
import torch
import vis
from collections import defaultdict
from concept_utils import *
from graph_utils import *
from matplotlib import rc
from neuron_metrics import *
from pipeline import load_dataset
from pipeline import train_standard_model
from torch_geometric.data import Data
from tqdm.notebook import tqdm
from vis import visualise_graph

import concept_ranker
import concepts

In [None]:
import dill


def strip(concepts):
    """Return a copy of a nested concept mapping with the first tuple slot blanked.

    ``concepts`` maps an outer key to an inner dict whose values are 3-tuples.
    Every ``(a, b, c)`` value becomes ``(None, b, c)``; keys are kept and the
    input mapping is left untouched.
    """
    return {
        outer_key: {
            inner_key: (None, second, third)
            for inner_key, (_, second, third) in inner.items()
        }
        for outer_key, inner in concepts.items()
    }


def save_concepts(concepts, fname: str):
    """Serialise ``concepts`` with dill to ``../concepts/<fname>.pkl``.

    The path is relative to the current working directory. The file is written
    inside a ``with`` block so the handle is flushed and closed even if
    serialisation raises.
    """
    import dill
    with open(f'../concepts/{fname}.pkl', 'wb') as f:
        dill.dump(concepts, f, protocol=dill.HIGHEST_PROTOCOL)


def load_concepts(fname: str):
    """Load and return the object stored at ``../concepts/<fname>.pkl``.

    NOTE: dill, like pickle, can execute arbitrary code while loading; only
    open files that were produced by this project.
    """
    # Context manager so the handle is closed once the object has been read.
    with open(f'../concepts/{fname}.pkl', 'rb') as f:
        return dill.load(f)

## MUTAG

In [None]:
# Load MUTAG; the third and last items returned by load_dataset are discarded here.
train_loader_mutag, test_loader_mutag, _, dataset_mutag, train_dataset, test_dataset, _ = load_dataset('MUTAG')
# GIN model for fold 0 (train_standard_model is a project helper from `pipeline`).
model_mutag = train_standard_model('MUTAG', 'GIN', fold=0)
# Concept search over the training split; the result is later passed to
# save_concepts and clean_concepts.
neuron_concepts = model_mutag.concept_search('MUTAG', train_dataset, depth=3, top=64, augment=False, level=1)

In [None]:
# Persist the raw concept-search result to ../concepts/mutag.pkl.
save_concepts(neuron_concepts, 'mutag')

### MUTAG: finding concepts

In [None]:
# clean_concepts returns two values; `cleaned_concepts` is what the cells below index by neuron.
cleaned_concepts, distilled = clean_concepts(neuron_concepts)
# Bare last expression: the notebook displays the full, uncleaned search result.
neuron_concepts

In [None]:
# For every (graph, unit) pair: explain class 1 using only that unit (rank={n}),
# then render the resulting mask to mutag_concepts/mutag-graph<g>_neuron<n>.png.
units = [22, 0, 50]
graphs = [11, 61, 151]
for g_ in graphs:
    for n in units:
        g = dataset_mutag.get(g_)
        final_mask, node_values = model_mutag.expl_gp_neurons(g, 1, debug=True, rank=set([n]), gamma=1000,
                                                              sigma=get_ths(cleaned_concepts),
                                                              names=get_names(cleaned_concepts),
                                                              scores=get_scores(cleaned_concepts), cum=True,
                                                              show_labels=False, show_node_mask=True, explore=True,
                                                              as_molecule=True, show_contribs=True, force=True)
        # node_values is deliberately not forwarded here (node_values=None);
        # the anchor is twice the unit's threshold from get_ths.
        vis.show_graph(Data(g.x, g.edge_index, edge_attr=g.edge_attr), final_mask,
                       node_values=None, show_labels=False,
                       anchor=get_ths(cleaned_concepts)[n] * 2,
                       as_molecule=True,
                       custom_name=f'mutag_concepts/mutag-graph{g_}_neuron{n}.png')


### MUTAG: global explanations

In [None]:
from concept_ranker import by_weight_adv, by_weight
import importlib
import concept_utils

# Reloading refreshes attributes on the `concept_utils` module object only; names that
# an earlier `from concept_utils import *` copied into the notebook namespace
# (get_ths, get_scores, ...) still refer to the previously loaded versions.
importlib.reload(concept_utils)

# Class to explain and the neuron ranking for it (neurons and their values from by_weight).
target_class = 1
neurons, vals = by_weight(model_mutag, target_class)
# `task` is used to build output file names; `dev` is read by get_global_vis.
task = 'mutag'
dev = model_mutag.device


def get_global_vis(n):
    """Return the training graph where column ``n`` of ``partial_forward`` peaks highest.

    Reads the notebook globals ``model_mutag``, ``train_dataset`` and ``dev``.

    Returns:
        ``(best_g, best_s)``: the selected graph and its peak value as a float;
        ``(None, -inf)`` when ``train_dataset`` is empty.
    """
    best_g = None
    best_s = float('-inf')
    for i in range(len(train_dataset)):
        g = train_dataset[i]
        pf = model_mutag.partial_forward(g.x.to(dev), g.edge_index.to(dev)).detach().cpu()
        score = pf[:, n].max().item()
        if score > best_s:
            best_s = score
            best_g = g
    # Return the best score, not the score of whichever graph was scanned last.
    return best_g, best_s


seen_concepts = set()

# Walk the ranked neurons and stop at the first non-positive value
# (presumably `vals` is sorted in descending order — confirm in concept_ranker).
for n, v in zip(neurons, vals):
    if v <= 0:
        break
    if n in cleaned_concepts:
        # Render each concept once: cleaned_concepts[n][1][2] is the deduplication key.
        if cleaned_concepts[n][1][2] in seen_concepts:
            continue

        g, _ = get_global_vis(n)
        final_mask, node_values = model_mutag.expl_gp_neurons(g, target_class, debug=True, rank=set([n]), gamma=1000,
                                                              sigma=get_ths(cleaned_concepts),
                                                              names=concept_utils.get_names(cleaned_concepts),
                                                              scores=get_scores(cleaned_concepts), cum=True,
                                                              show_labels=False, show_node_mask=True, explore=True,
                                                              as_molecule=True, show_contribs=True, force=True)
        # `{v:.4f}`: a space before the format spec is the "space sign" flag and would
        # insert a blank into the file name for positive values.
        vis.show_graph(Data(g.x, g.edge_index,
                            edge_attr=g.edge_attr), final_mask, node_values=node_values, show_labels=False,
                       anchor=get_ths(cleaned_concepts)[n] * 3, as_molecule=True,
                       custom_name=f'{task}_global/{task}_global_class{target_class}_neuron{n}_{v:.4f}.png')
        seen_concepts.add(cleaned_concepts[n][1][2])

In [None]:
# Explain class 1 on three MUTAG graphs using unit 50 only, rendered as plain graphs.
graphs = [11, 61, 151]
unit = 50
# NOTE(review): this rebinds `concepts`, which the import cell bound to the `concepts`
# module; the set itself is not used in this cell (rank is built from `unit` directly).
concepts = set([unit])

for g_ in graphs:
    g = dataset_mutag.get(g_)
    final_mask, node_values = model_mutag.expl_gp_neurons(g, 1, debug=True, rank=set([unit]), gamma=1000,
                                                          sigma=get_ths(cleaned_concepts),
                                                          names=get_names(cleaned_concepts),
                                                          scores=get_scores(cleaned_concepts), cum=True,
                                                          show_labels=False, show_node_mask=True, explore=True,
                                                          as_molecule=True, show_contribs=True, force=True)
    vis.show_graph(Data(g.x, g.edge_index, edge_attr=g.edge_attr), final_mask, node_values=node_values,
                   show_labels=False, anchor=get_ths(cleaned_concepts)[unit] * 2.0,
                   custom_name=f'mutag_graph{g_}unit{unit}', as_molecule=False)

In [None]:
import vis

g = dataset_mutag.get(153)

# Explain class 0 on graph 153. Unlike the other cells, `rank` is the integer 64 rather
# than a set of neuron ids, `entropic=False` is passed and `force` is not.
final_mask, node_values = model_mutag.expl_gp_neurons(g, 0, debug=True, rank=64, gamma=1000,
                                                      sigma=get_ths(cleaned_concepts),
                                                      names=get_names(cleaned_concepts),
                                                      scores=get_scores(cleaned_concepts), cum=True, show_labels=False,
                                                      show_node_mask=True, explore=True, as_molecule=True,
                                                      show_contribs=True, entropic=False)

In [None]:
# Render the mask computed in the previous cell for graph 153.
# NOTE(review): `unit` in the file name is whatever an earlier cell left bound (50 if run
# in order), although the explanation above was not restricted to a single unit.
vis.show_graph(Data(g.x, g.edge_index, edge_attr=g.edge_attr), final_mask, node_values=None, show_labels=False,
               anchor=None, custom_name=f'mutag_graph153unit{unit}', as_molecule=False)

## IMDB

In [None]:
# GIN model for IMDB, fold 0. From here on `model` replaces `model_mutag` as the subject.
model = train_standard_model('IMDB', 'GIN', fold=0)

# Rebinds train_dataset / test_dataset, which previously held the MUTAG splits.
train_loader, test_loader, val_loader, dataset, train_dataset, test_dataset, val_dataset = load_dataset('IMDB')

In [None]:
# Concept search on the IMDB training split (overwrites the MUTAG `neuron_concepts`).
neuron_concepts = model.concept_search('IMDB', train_dataset, depth=2, top=64, augment=False, omega=[10, 20, 20])

In [None]:
# Persist the raw concept-search result to ../concepts/IMDB.pkl.
save_concepts(neuron_concepts, 'IMDB')

In [None]:
# Rebinds cleaned_concepts / distilled to the IMDB results.
cleaned_concepts, distilled = clean_concepts(neuron_concepts)

### IMDB: searching for concepts

In [None]:
import importlib
import vis

dev = model.device
# For every (graph, unit) pair: explain class 1 using only that unit and save an SVG
# under imdb_concepts/.
units = [21, 12, 7]
graphs = [631, 714, 58]
for g_ in graphs:
    for n in units:
        g = dataset[g_]
        # NOTE(review): as_molecule=True is passed here although show_graph below and the
        # IMDB global-explanation cell use as_molecule=False — confirm this is intended.
        final_mask, node_values = model.expl_gp_neurons(g, 1, debug=False, rank=set([n]), gamma=1000,
                                                        sigma=get_ths(cleaned_concepts),
                                                        names=get_names(cleaned_concepts),
                                                        scores=get_scores(cleaned_concepts), cum=True,
                                                        show_labels=False, show_node_mask=True, explore=True,
                                                        as_molecule=True, show_contribs=True, force=True)
        vis.show_graph(Data(g.x, g.edge_index, edge_attr=g.edge_attr), final_mask,
                       node_values=None, show_labels=False,
                       anchor=get_ths(cleaned_concepts)[n],
                       as_molecule=False,
                       custom_name=f'imdb_concepts/imdb-graph{g_}_neuron{n}.svg')


In [None]:
dev = model.device

import node_explainer
import importlib

graphs = [631, 714, 58]
unit = 59
# Rebinds `concepts` (bound to the `concepts` module by the import cell) to the rank set
# that is passed to expl_gp_neurons below.
concepts = set([unit])
y = 1

for g_ in graphs:
    g = dataset[g_]

    # NOTE(review): gamma=1030 and cum=False differ from the gamma=1000 / cum=True used by
    # the other cells, and `anchor` is passed to expl_gp_neurons only here — confirm.
    final_mask, node_values = model.expl_gp_neurons(g, y, debug=True, gamma=1030, sigma=get_ths(cleaned_concepts),
                                                    names=get_names(cleaned_concepts),
                                                    scores=get_scores(cleaned_concepts), rank=concepts, cum=False,
                                                    show_node_mask=True, anchor=0.44, show_contribs=True, force=True,
                                                    explore=True)
    # No custom_name is given for these renderings.
    vis.show_graph(Data(g.x, g.edge_index, edge_attr=g.edge_attr), final_mask, node_values=node_values,
                   show_labels=False, anchor=get_ths(cleaned_concepts)[unit] * 1.2, as_molecule=False)

### IMDB: global explanations

In [None]:
from concept_ranker import by_weight

# Class to explain and the neuron ranking for it, consumed by the global loop below.
target_class = 1
neurons, vals = by_weight(model, target_class)

In [None]:

def get_global_vis(n):
    """Find the graph in ``dataset`` where column ``n`` of ``partial_forward`` peaks highest.

    Graphs whose ``x`` has more than 100 rows are ignored. Reads the notebook
    globals ``model``, ``dataset`` and ``dev``.

    Returns:
        ``(graph, score)`` for the best graph, or ``(None, -inf)`` if no graph qualified.
    """
    top_graph, top_score = None, float('-inf')
    for idx in range(len(dataset)):
        candidate = dataset[idx]
        if candidate.x.shape[0] <= 100:
            activations = model.partial_forward(candidate.x.to(dev), candidate.edge_index.to(dev)).detach().cpu()
            peak = activations[:, n].max().item()
            if peak > top_score:
                top_graph, top_score = candidate, peak
    return top_graph, top_score


seen_concepts = set()

# Walk the ranked neurons and stop at the first non-positive value
# (presumably `vals` is sorted in descending order — confirm in concept_ranker).
for n, v in zip(neurons, vals):
    if v <= 0:
        break
    if n in cleaned_concepts:
        # Render each concept once: cleaned_concepts[n][1][2] is the deduplication key.
        if cleaned_concepts[n][1][2] in seen_concepts:
            continue
        print(f'Neuron {n} Concept {cleaned_concepts[n]}')
        g, sc = get_global_vis(n)
        final_mask, node_values = model.expl_gp_neurons(g, target_class, debug=True, rank=set([n]), gamma=1000,
                                                        sigma=get_ths(cleaned_concepts),
                                                        names=get_names(cleaned_concepts),
                                                        scores=get_scores(cleaned_concepts), cum=True,
                                                        show_labels=False, show_node_mask=True, explore=True,
                                                        as_molecule=False, show_contribs=True, force=True)
        # `{v:.4f}`: a space before the format spec is the "space sign" flag and would
        # insert a blank into the file name for positive values.
        vis.show_graph(Data(g.x, g.edge_index, edge_attr=g.edge_attr), final_mask,
                       node_values=node_values, show_labels=False,
                       anchor=get_ths(cleaned_concepts)[n], as_molecule=False,
                       custom_name=f'imdb_global/imdb_global_class{target_class}_neuron{n}_{v:.4f}.svg')
        seen_concepts.add(cleaned_concepts[n][1][2])

## REDDIT

In [None]:
# Release cached, unused GPU memory held by PyTorch's allocator before the REDDIT run.
torch.cuda.empty_cache()

In [None]:
# GCN model for REDDIT, fold 0; rebinds `model` and the dataset/split names used for IMDB.
model = train_standard_model('REDDIT', 'GCN', fold=0)
train_loader, test_loader, _, dataset, train_dataset, test_dataset, val_dataset = load_dataset('REDDIT')

In [None]:
# Concept search on the REDDIT training split (overwrites the IMDB `neuron_concepts`).
neuron_concepts = model.concept_search('REDDIT', train_dataset, depth=2, top=64, augment=False, omega=[10, 25, 20])

In [None]:
# Persist the raw concept-search result to ../concepts/REDDIT.pkl.
save_concepts(neuron_concepts, 'REDDIT')

### REDDIT: searching for concepts

In [None]:
# Rebinds cleaned_concepts / distilled to the REDDIT results.
cleaned_concepts, distilled = clean_concepts(neuron_concepts)

In [None]:
# For every (graph, unit) pair: explain class 1 using only that unit and save an SVG
# under reddit_concepts/.
units = [46, 51]
graphs = [865, 271, 534]
for g_ in graphs:
    for n in units:
        g = dataset[g_]
        # NOTE(review): as_molecule=True is passed here although show_graph below and the
        # other REDDIT cells use as_molecule=False — confirm this is intended.
        final_mask, node_values = model.expl_gp_neurons(g, 1, debug=True, rank=set([n]), gamma=1000,
                                                        sigma=get_ths(cleaned_concepts),
                                                        names=get_names(cleaned_concepts),
                                                        scores=get_scores(cleaned_concepts), cum=True,
                                                        show_labels=False, show_node_mask=True, explore=True,
                                                        as_molecule=True, show_contribs=True, force=True)
        vis.show_graph(Data(g.x, g.edge_index, edge_attr=g.edge_attr), final_mask,
                       node_values=node_values, show_labels=False,
                       anchor=get_ths(cleaned_concepts)[n],
                       as_molecule=False,
                       custom_name=f'reddit_concepts/reddit-graph{g_}_neuron{n}.svg')


In [None]:
dev = model.device
# The graph and its activations do not depend on j, so fetch and compute them once
# instead of repeating the forward pass for each of the 64 columns.
g = dataset[184]
pf = model.partial_forward(g.x.to(dev), g.edge_index.to(dev)).detach().cpu()
# Print every column index whose activation is positive on more than two rows of `pf`.
for j in range(64):
    pos_count = (pf[:, j] > 0).sum()
    if pos_count > 2:
        print(f'{j}  {pos_count}')


In [None]:
# Print the indices (among the first 1000) of graphs with label 0 whose `x` has
# strictly between 190 and 200 rows.
for g in range(1000):
    graph = dataset[g]
    n_rows = graph.x.shape[0]
    if n_rows > 190 and n_rows < 200 and graph.y == 0:
        print(g)

In [None]:
import vis

# Explain class 0 on graph 280 using unit 51 only.
graphs = [280]
unit = 51
# NOTE(review): this rebinds `concepts`, which the import cell bound to the `concepts`
# module; the set itself is not used in this cell (rank is built from `unit` directly).
concepts = set([unit])

for g_ in graphs:
    g = dataset[g_]
    final_mask, node_values = model.expl_gp_neurons(g, 0, debug=True,
                                                    rank=set([unit]), gamma=1000, sigma=get_ths(cleaned_concepts),
                                                    names=get_names(cleaned_concepts),
                                                    scores=get_scores(cleaned_concepts), cum=True, show_labels=False,
                                                    show_node_mask=True, explore=True, as_molecule=False,
                                                    show_contribs=True, force=True)
    vis.show_graph(Data(g.x, g.edge_index, edge_attr=g.edge_attr), final_mask, node_values=node_values,
                   show_labels=False, anchor=get_ths(cleaned_concepts)[unit] * 1.0,
                   custom_name=f'reddit_graph{g_}unit{unit}', as_molecule=False)

### REDDIT: global explanations

In [None]:
from concept_ranker import by_weight_adv

# Class to explain and its neuron ranking. This section ranks with by_weight_adv,
# whereas the MUTAG and IMDB sections use by_weight.
target_class = 1
neurons, vals = by_weight_adv(model, target_class)
# `task` is used to build output file names.
task = 'reddit'


def get_global_vis(n):
    """Return the graph where column ``n`` of ``partial_forward`` peaks highest.

    Scans ``dataset[i]`` for ``i < len(train_dataset)``, skipping graphs whose
    ``x`` has more than 300 rows. Reads the notebook globals ``model``,
    ``dataset``, ``train_dataset`` and ``dev``.

    Returns:
        ``(best_g, best_s)``: the selected graph and its peak value as a float;
        ``(None, -inf)`` when every graph was skipped.
    """
    # NOTE(review): the loop bound comes from `train_dataset` but graphs are read from
    # `dataset` — confirm which collection is intended.
    best_g = None
    best_s = float('-inf')
    for i in range(len(train_dataset)):
        g = dataset[i]
        if g.x.shape[0] > 300:
            continue
        pf = model.partial_forward(g.x.to(dev), g.edge_index.to(dev)).detach().cpu()
        score = pf[:, n].max().item()
        if score > best_s:
            best_s = score
            best_g = g
    # Return the best score, not the score of whichever graph was scanned last
    # (which is also undefined when every graph is skipped).
    return best_g, best_s


seen_concepts = set()

# Walk the ranked neurons and stop at the first non-positive value
# (presumably `vals` is sorted in descending order — confirm in concept_ranker).
for n, v in zip(neurons, vals):
    if v <= 0:
        break
    if n in cleaned_concepts:
        # Render each concept once. NOTE(review): the deduplication key here is
        # cleaned_concepts[n][0], while the MUTAG and IMDB loops use [n][1][2] — confirm.
        if cleaned_concepts[n][0] in seen_concepts:
            continue

        g, _ = get_global_vis(n)
        final_mask, node_values = model.expl_gp_neurons(g, target_class, debug=True, rank=set([n]), gamma=1000,
                                                        sigma=get_ths(cleaned_concepts),
                                                        names=get_names(cleaned_concepts),
                                                        scores=get_scores(cleaned_concepts), cum=True,
                                                        show_labels=False, show_node_mask=True, explore=True,
                                                        as_molecule=False, show_contribs=True, force=True)
        # `{v:.4f}`: a space before the format spec is the "space sign" flag and would
        # insert a blank into the file name for positive values.
        vis.show_graph(Data(g.x, g.edge_index,
                            edge_attr=g.edge_attr), final_mask, node_values=node_values, show_labels=False,
                       anchor=get_ths(cleaned_concepts)[n], as_molecule=False,
                       custom_name=f'{task}_global/{task}_global_class{target_class}_neuron{n}_{v:.4f}.svg')
        seen_concepts.add(cleaned_concepts[n][0])

# CONCEPTS AT DIFFERENT TRAINING EPOCHS AND LAYERS

In [None]:
# For each training budget, train a model per dataset with `custom_epochs=epochs`,
# run the concept search on its training split and save the result as "<DATASET> <epochs>".
# NOTE(review): the list steps by 10 from 10 to 400 but has no 280 — confirm whether the
# omission is intentional.
for epochs in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170,
               180, 190, 200, 210, 220, 230, 240, 250, 260, 270, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390,
               400]:
    train_loader_mutag, test_loader_mutag, _, dataset_mutag, train_dataset, test_dataset, _ = load_dataset('MUTAG')
    # NOTE(review): unlike the REDDIT and PROTEINS calls below, this one does not pass
    # overwrite=True — confirm.
    model = train_standard_model('MUTAG', 'GIN', custom_name=f'mutag-train-{epochs}-epochs', custom_epochs=epochs)
    neuron_concepts = model.concept_search('MUTAG', train_dataset, depth=3, top=64, augment=False, level=1)
    save_concepts(neuron_concepts, f'MUTAG {epochs}')
    torch.cuda.empty_cache()

    model = train_standard_model('REDDIT', 'GCN', custom_name=f'reddit-train-{epochs}-epochs', custom_epochs=epochs,
                                 overwrite=True)
    # train_dataset is rebound to the REDDIT split before it is used by concept_search.
    train_loader, test_loader, _, dataset, train_dataset, test_dataset, val_dataset = load_dataset('REDDIT')
    neuron_concepts = model.concept_search('REDDIT', train_dataset, depth=2, top=64, augment=False,
                                           omega=[10, 25, 20])
    save_concepts(neuron_concepts, f'REDDIT {epochs}')
    torch.cuda.empty_cache()

    # custom_es=100000 is passed for PROTEINS only.
    model = train_standard_model('PROTEINS', 'GIN', custom_name=f'protein-train-{epochs}-epochs', custom_epochs=epochs,
                                 custom_es=100000, overwrite=True)
    train_loader, test_loader, _, dataset, train_dataset, test_dataset, val_dataset = load_dataset('PROTEINS')
    neuron_concepts = model.concept_search('PROTEINS', train_dataset, depth=4, top=64, augment=False,
                                           omega=[15, 25, 20])
    save_concepts(neuron_concepts, f'PROTEINS {epochs}')
    torch.cuda.empty_cache()

In [None]:
# For a 10-layer MUTAG GIN, run the concept search at every level 1..10, three times each,
# saving each result as "MUTAG-deep-level<level>-fold<fold>".
for level in range(1, 11):
    for fold in range(3):
        # Bind `train_dataset` here (the name was previously misspelled), so concept_search
        # runs on the MUTAG split rather than on whatever an earlier cell left in scope.
        train_loader_mutag, test_loader_mutag, _, dataset_mutag, train_dataset, test_dataset, _ = load_dataset('MUTAG')
        # NOTE(review): `fold` only appears in the model / file names; it is not passed to
        # train_standard_model (which accepts fold=...) — confirm whether it should be.
        model = train_standard_model('MUTAG', 'GIN', custom_layers=10, custom_name=f'mutag-deep-fold{fold}',
                                     overwrite=True)
        neuron_concepts = model.concept_search('MUTAG', train_dataset, depth=level, top=64, augment=False, level=level)
        save_concepts(neuron_concepts, f'MUTAG-deep-level{level}-fold{fold}')
        torch.cuda.empty_cache()