In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import networkx as nx
import numpy as np
import pandas as pd
import random
import os
import sklearn.cluster 
import plotly.express as px
import community

import msg_passing
import utils
import run
import display
import parse_data
import baselines

In [3]:
def load_results(fdir, f_prefix):
    g = msg_passing.load_graph_graphml(fdir + f_prefix + ".graphml")
    hist, diagnostic_hist = msg_passing.load_history(fdir + f_prefix + ".pkl")
    return g, hist, diagnostic_hist

def view_history(fdir, f_prefix):
    hist, _ = msg_passing.load_history(fdir + f_prefix + ".pkl") 
    return hist.keys()

def plot_run(fdir, f_prefix, target):
    print("Plotting run for", fdir + f_prefix)
    g = msg_passing.load_graph_graphml(fdir + f_prefix + ".graphml")
    hist, diagnostic_hist = msg_passing.load_history(fdir + f_prefix + ".pkl")
    display.plot_diagnostic(diagnostic_hist)
    display.plot_history(hist, target=target)

def summmarize_diagnostic(network_dir, network_names, network_suffix):
    hist_files = []
    graph_files = []
    for n in network_names:
        h = network_dir + n + network_suffix + ".pkl"
        g = network_dir + n + network_suffix + ".graphml"
        hist_files.append(h)
        graph_files.append(g)

    fig = display.plot_diagnostic_multiple_issues(hist_files, graph_files, network_names, show=True)
    #fig.write_image("convergence_plots.pdf", height=450, width=986)
    return fig 

def summarize_histogram(network_dir, network_names, network_suffix):
    graph_files = [network_dir + n + network_suffix + ".graphml" for n in network_names]
    fig = display.plot_cos_dist_histogram_grid(graph_files, network_names, "Cos Dist Histogram", 4, 2)
    return fig 

def summarize_confusion_matrix(network_dir, network_names, network_suffix, top_n_nodes, num_nodes=None, drop_noise=True):
    view_dir = "images/view/"
    for n in network_names:
        filepath = network_dir + n + network_suffix + ".graphml"
        g = msg_passing.load_graph_graphml(filepath) 
        pg, _ = msg_passing.prune_graph(g) 

        if(not num_nodes):
            num_nodes = len(pg.nodes())
        nodes = utils.get_top_n_nodes(pg, num_nodes)
        fig = display.plot_confusion_matrix_with_random_baseline(pg, nodes, top_n_nodes, drop_noise=True, title="distance matrix")
        fig.write_html(view_dir + n + ".html")

In [12]:
infiles = [
    "global_warming_network",
    "gun_regulations_network",
    "immigration_network",
    "inflation_network",
    "roe_v_wade_network",
    "trump_impeachment_network",
    "ukraine_war_network",
    "vaccine_hesitancy_network",
    "combined"
]

infiles = [
    "global_warming",
    "gun_regulations",
    "immigration",
    "recession_fears",
    "roe_v_wade",
    "ukraine_war",
    "vaccine_hesitancy",
]

network_names = [ 
    "global warming",
    "gun regulations",
    "immigration",
    "recession fears",
    "roe v. wade",
    "ukraine war",
    "vaccine hesitancy",
]

indir = "input/Networks/"

#outdir = "output/Incremental_Datasets_2/"
#outdir = "output/archive/"

outdir = "output/Networks_v1/"
#outdir = "output/random_edges/"
outdir = "output/with_windows/"
#outdir = "output/degree_penalty/"
#outdir = "output/correctly_weighted/"
#outdir = "output/degree_penalty_correctly_weighted/"

in_suffix = "_network.csv"

#out_suffix = "_random_walks_lr_10-3_30K_dc_095_pl_10_bs_10"
#out_suffix = "_v3_random_walk_batch_10_path_10_50K_lr_3_dim_3"

out_suffix = "_network_random_walks_lr_10-3_20K_dc_095_pl_{path_length}_bs_10"
#out_suffix = "_network_random_edge_baseline_random_walks_lr_10-3_20K_dc_095_pl_10_bs_10" 
out_suffix = "_network_random_walks_lr_10-3_20K_dc_095_pl_10_bs_10"
#out_suffix = "_network_degree_penalty_lr_10-3_20K_dc_095_pl_10_bs_10"
#out_suffix = "_network_correctly_weighted_lr_10-3_20K_dc_095_pl_10_bs_10"
#out_suffix = "_network_degree_penalty_correctly_weighted_lr_10-3_20K_dc_095_pl_10_bs_10"

net = infiles[5] 
g, hist, diagnostic_hist = load_results(outdir, net + out_suffix)


#print(hist.keys())
#display.plot_history_with_reference(hist, "anthony fauci")
#display.plot_top_n_cluster_evaluations(g, [20, 50, 100], 2, 10)
#print(view_history(outdir, net + out_suffix))
#plot_run(outdir, net + out_suffix, "biden")
#_ = display.plot_edge_weight_histogram(g, log_scale=False)
#_ = display.plot_degree_histogram(g, log_scale=True)

#fig = display.plot_confusion_matrix(g, utils.get_top_n_nodes(g, 20), 2, title=net)
#_ = display.plot_cos_dist_histogram(g, title=net)
#fig.write_html("images/vaccine_heatmap.html")



#_ = summmarize_diagnostic(outdir, infiles, out_suffix)
#_ = summarize_histogram(outdir, infiles, out_suffix)
#summarize_confusion_matrix(outdir, infiles, out_suffix, 20, num_nodes=50, drop_noise=True)
#_ = display.plot_diagnostic_grid(hist_files, graph_files, infiles, "Update Magnitude and Loss", 4, 2)


g = msg_passing.load_graph_graphml(outdir + infiles[0] + out_suffix + ".graphml")
ag, _ = msg_passing.prune_graph(utils.largest_connected_component(g))
g = msg_passing.permute_edges(g) 
g = utils.largest_connected_component(g)
pg, _ = msg_passing.prune_graph(g)
msg_passing.save_graph(pg, "output/test.graphml")
msg_passing.save_graph(ag, "output/test2.graphml")

"""
write_dir = "output/pruned_graphs/"
for n in infiles:
    print(n)
    fname = outdir + n + out_suffix + ".graphml"
    #fname = indir + n + in_suffix
    #g = msg_passing.load_graph_csv(fname, clean_data=True)
    #msg_passing.save_graph(g, write_dir + n + "_full.graphml")
    g = msg_passing.load_graph_graphml(fname)
    g = utils.largest_connected_component(g)
    pg, _ = msg_passing.prune_graph(g)
    msg_passing.save_graph(pg, write_dir + n + ".graphml")
"""

'\nwrite_dir = "output/pruned_graphs/"\nfor n in infiles:\n    print(n)\n    fname = outdir + n + out_suffix + ".graphml"\n    #fname = indir + n + in_suffix\n    #g = msg_passing.load_graph_csv(fname, clean_data=True)\n    #msg_passing.save_graph(g, write_dir + n + "_full.graphml")\n    g = msg_passing.load_graph_graphml(fname)\n    g = utils.largest_connected_component(g)\n    pg, _ = msg_passing.prune_graph(g)\n    msg_passing.save_graph(pg, write_dir + n + ".graphml")\n'

In [4]:
def get_shortest_path_network(g, nodes1, nodes2):
    subgraph_nodes = set()
    for n1 in nodes1:
        for n2 in nodes2:
            sps = list(nx.all_shortest_paths(g, n1, n2))
            for sp in sps:
                for n in sp:
                    subgraph_nodes.add(n) 
    
    return g.subgraph(subgraph_nodes)

gf = "output/Networks_v1/gun_regulations_network_random_walks_lr_10-3_20K_dc_095_pl_{path_length}_bs_10.graphml"
#gf = "output/Networks_v1/roe_v_wade_network_random_walks_lr_10-3_20K_dc_095_pl_{path_length}_bs_10.graphml"
gf = "output/Networks_v1/recession_fears_network_random_walks_lr_10-3_20K_dc_095_pl_{path_length}_bs_10.graphml"
g = msg_passing.load_graph_graphml(gf)
pg, _ = msg_passing.prune_graph(g) 
n1 = ["second amendment", "bruen", "greg abbott", "gerald smith", "iowa firearms coalition", "supreme court"]
#n2 = ["second amendment foundation", "adam kraut"]
n2 = ["joe biden", "kamala harris", "a ban on assault weapons", "biden", "white house"]
#n1 = ["donald trump", "white house", "supreme court", "clarence thomas"]
#n2 = ["republican", "republicans", "gop", "arizona", "planned parenthood", "anti-abortion"]
n1 = ["opec", "fox news", "saudis"]
n2 = ["jerome powell", "joe biden", "janet yellen"]

sg = get_shortest_path_network(pg, n1, n2)
write_dir = "output/cyto_pruned_networks/"
outname = write_dir + "recession_1.graphml" 
msg_passing.save_graph(sg, outname)

In [39]:
def cluster_and_compute_silhouette_score(g, nodes):
    labels = parse_data.cluster_nodes_hdbscan(g, nodes) 
    score = parse_data.evaluate_clustering(g, nodes, labels)
    return labels, score

def cluster_and_evaluate_issues(network_dir, network_names, network_suffix, num_nodes=None):
    for name in network_names:
        filepath = network_dir + name + network_suffix
        g = msg_passing.load_graph_graphml(filepath + ".graphml") 
        pg, _ = msg_passing.prune_graph(g)
        
        # manually change
        if(not num_nodes):
            total_nodes = len(pg.nodes())
            num_nodes = total_nodes
        nodes = utils.get_top_n_nodes(pg, num_nodes)
        labels, score = cluster_and_compute_silhouette_score(pg, nodes) 

        num_clusters = len(set(labels))
        print(f"Results for {name}: total nodes: {total_nodes}, num clusters: {num_clusters}, score: {score}")
        
infiles = [
    "gun_regulations",
    "immigration",
    "recession_fears",
    "roe_v_wade",
    "ukraine_war",
    "vaccine_hesitancy",
]
outdir = "output/with_windows/"
#outdir = "output/random_permutation/"
#outdir = "output/random_edges/"
#outdir = "output/random_weights/"
#outdir = "output/degree_penalty_correctly_weighted/"
out_suffix = "_network_random_walks_lr_10-3_20K_dc_095_pl_10_bs_10"
#out_suffix = "_network_random_permutation_baseline_random_walks_lr_10-3_20K_dc_095_pl_10_bs_10"
#out_suffix = "_network_random_edge_baseline_random_walks_lr_10-3_20K_dc_095_pl_10_bs_10"
#out_suffix = "_network_random_weight_baseline_random_walks_lr_10-3_20K_dc_095_pl_10_bs_10"
#out_suffix = "_network_degree_penalty_correctly_weighted_lr_10-3_20K_dc_095_pl_10_bs_10"
#cluster_and_evaluate_issues(outdir, infiles, out_suffix)
#cluster_and_evaluate_issues_hdbscan(outdir, infiles, out_suffix)
#cluster_and_evaluate_issues_louvain(outdir, infiles, out_suffix)
#baselines.evaluate_multiple_issues_louvain(outdir, infiles, out_suffix, num_nodes=None, prune=True)
baselines.evaluate_multiple_issues_hdbscan(outdir, infiles, out_suffix, num_nodes=None, prune=True)

# TODO: plot degree distribution of permuted graphs



HDBSCAN clustering results for gun_regulations - num clusters: 13, score: -0.040456127676587926
HDBSCAN clustering results for immigration - num clusters: 10, score: -0.02081121038984716
HDBSCAN clustering results for recession_fears - num clusters: 6, score: -0.03717756905965803
HDBSCAN clustering results for roe_v_wade - num clusters: 6, score: -0.019887985789492788
HDBSCAN clustering results for ukraine_war - num clusters: 13, score: -0.02495503297768776
HDBSCAN clustering results for vaccine_hesitancy - num clusters: 4, score: -0.015456129509601893
