# Visualize Counterfactual Graphs

In [10]:
import os
import pickle
from pprint import pprint
import shutil

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
import torch

In [2]:
# Pickled dataset to read, and the directory that collects every generated file.
INPUT, OUTPUT = "../data/syn1_dataset.pkl", "../output/syn1"
# Create the output directory up front; an already existing one is left as is.
os.makedirs(name=OUTPUT, exist_ok=True)

## Data

In [3]:
# Read the whole dataset file and deserialize it into `data`.
# NOTE: unpickling can execute arbitrary code, so only load trusted files.
with open(INPUT, mode="rb") as file:
    data = pickle.loads(file.read())

In [4]:
# Show the container type followed by a blank line, then pretty-print the
# entry stored under the first key to reveal the layout of one record.
print(type(data), end="\n\n")
first_key = list(data)[0]
pprint(data[first_key])

<class 'dict'>

{'adj': array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 1., 1.],
       [0., 0., 0., ..., 1., 0., 1.],
       [0., 0., 0., ..., 1., 1., 0.]], dtype=float32),
 'cfs': [[27, 62, 'del'], [102, 28, 'add']],
 'target': 102}


In [30]:
# Normalize every counterfactual edge so that the smaller node label comes
# first (source <= destination). This simplifies plotting later on.
for record in data.values():
    cfs = record['cfs']
    for position, cf in enumerate(cfs):
        src, dest, action = cf[:3]
        if src > dest:
            # Overwrite in place so the entry keeps its position in the list.
            cfs[position] = [dest, src, action]

In [31]:
# Pretty-print the same sample entry again to inspect its 'cfs' list after
# the reordering step.
pprint(data[first_key])

{'adj': array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 1., 1.],
       [0., 0., 0., ..., 1., 0., 1.],
       [0., 0., 0., ..., 1., 1., 0.]], dtype=float32),
 'cfs': [[27, 62, 'del'], [28, 102, 'add']],
 'target': 102}


## Visualize

### Graph

In [None]:
# Graphs whose `adj.sum() // 2` exceeds this value are skipped when plotting.
THRESHOLD = 15
# Directory that receives one PNG per plotted graph.
OUT_PATH = f"{OUTPUT}/graph/"
# Remove images left over from an earlier run, then recreate the directory.
if os.path.exists(OUT_PATH):
    shutil.rmtree(OUT_PATH)
os.makedirs(OUT_PATH)

In [22]:
# Plot every sufficiently small graph and save it as "<node_id>.png" in OUT_PATH.
# The target node is drawn in orange, every other node in dark blue.

# Fixed seed for the force-directed layout so that re-running the cell gives
# the same node positions for the same graph. Use the same seed value wherever
# an identical layout of the same graph is wanted.
LAYOUT_SEED = 0

for node_id, entry in data.items():
    adj = entry['adj']
    # Discard large graphs. `adj.sum() // 2` is the number of edges if `adj`
    # is a symmetric 0/1 matrix (assumption -- confirm against the dataset).
    if adj.sum() // 2 > THRESHOLD:
        continue
    graph = nx.from_numpy_array(adj)
    pos = nx.spring_layout(graph, seed=LAYOUT_SEED)
    # Assign colors to the nodes.
    target = entry['target']
    color_map = [
        "#F28C28" if node == target else "#00308F"
        for node in graph
    ]
    nx.draw(G=graph, pos=pos, node_color=color_map, width=1.5)
    # Save the plot.
    plt.savefig(OUT_PATH + f"{node_id}" + ".png", format="png")
    # Clear the figure for the next iteration.
    # Otherwise the next graph would be drawn on top of this one.
    plt.clf()

<Figure size 432x288 with 0 Axes>

### Colored Map

In [32]:
# Graphs whose `adj.sum() // 2` exceeds this value are skipped when plotting.
THRESHOLD = 15
# Directory that receives one PNG per edge-colored graph.
OUT_PATH = f"{OUTPUT}/colored_graphs/"
# Remove images left over from an earlier run, then recreate the directory.
if os.path.exists(OUT_PATH):
    shutil.rmtree(OUT_PATH)
os.makedirs(OUT_PATH)

In [None]:

# TODO: Edge colors are not visible.
# TODO: The "Graph" and "Colored Map" sections do not produce identical plots for the same graph.

In [34]:
# Plot every sufficiently small graph with its counterfactual edges highlighted
# and save it as "<node_id>.png" in OUT_PATH.
#   Nodes: orange for the target node, dark blue for every other node.
#   Edges: green for a 'del' counterfactual, red for any other counterfactual
#          action, black for edges that are not counterfactuals.
# NOTE: only edges present in `adj` are iterated and drawn, so a counterfactual
# edge that is absent from `adj` never receives a color in the plot.

# Fixed seed for the force-directed layout so that re-running the cell gives
# the same node positions for the same graph. Use the same seed value wherever
# an identical layout of the same graph is wanted.
LAYOUT_SEED = 0

for node_id, entry in data.items():
    adj = entry['adj']
    # Discard large graphs. `adj.sum() // 2` is the number of edges if `adj`
    # is a symmetric 0/1 matrix (assumption -- confirm against the dataset).
    if adj.sum() // 2 > THRESHOLD:
        continue
    # Create the graph.
    graph = nx.from_numpy_array(adj)
    pos = nx.spring_layout(graph, seed=LAYOUT_SEED)
    # Assign colors to the nodes.
    target = entry['target']
    color_map = [
        "#F28C28" if node == target else "#00308F"
        for node in graph
    ]
    # Map each counterfactual edge to its action. Keys store the smaller label
    # first, so the lookup works regardless of how an edge is oriented in
    # 'cfs' or in graph.edges. setdefault keeps the first action when the same
    # edge is listed more than once.
    cf_action_by_edge = {}
    for cf in entry['cfs']:
        cf_key = (min(cf[0], cf[1]), max(cf[0], cf[1]))
        cf_action_by_edge.setdefault(cf_key, cf[2])
    # Assign colors to the edges: one color per edge, in graph.edges order.
    edge_color_map = []
    for u, v in graph.edges:
        edge_key = (min(u, v), max(u, v))
        if edge_key not in cf_action_by_edge:
            edge_color_map.append("black")
        elif cf_action_by_edge[edge_key] == 'del':
            edge_color_map.append("green")
        else:
            edge_color_map.append("red")
    # Draw
    nx.draw(
        G=graph,
        pos=pos,
        node_color=color_map,
        width=1.5,
        edge_color=edge_color_map,
    )
    # Save the plot.
    plt.savefig(OUT_PATH + f"{node_id}" + ".png", format="png")
    # Clear the figure for the next iteration.
    # Otherwise the next graph would be drawn on top of this one.
    plt.clf()

<Figure size 432x288 with 0 Axes>

### Perturbed Graph

In [None]:
# Make another copy of the graph: perturbed_graph.
# Make the perturbations.
# Plot the graph.

## Rough

In [26]:
# Count how many graphs pass the size filter `adj.sum() // 2 <= THRESHOLD`.
# `adj.sum() // 2` is the number of edges if `adj` is a symmetric 0/1 matrix
# (assumption -- confirm against the dataset), so the report below speaks of
# edges, and of "at most" because the comparison is inclusive.
THRESHOLD = 15
count = sum(
    1 for entry in data.values()
    if entry['adj'].sum() // 2 <= THRESHOLD
)
print(f"#Graphs: {len(data)}")
print(f"#Graphs with at most {THRESHOLD} edges: {count}")

#Graphs: 209
#Graphs with less than 15 nodes: 14
