In [1]:
import pandas as pd
import networkx as nx
import pickle as pkl
import matplotlib.pyplot as plt
from methods import find_groups

# First create an nx.Graph() object. This example loads a graph from the OGBG-MolHIV chemistry dataset; replace this code to build your own graph.

In [2]:
"""Protein numbers for:
-> `OGBG-MolHIV_visualization.ipynb`.
182987
135257
200289
104795

-> Default: 349519
"""

PROTEIN_NR_MOLPCBA = 104795

pkl_filename = f"graph_{PROTEIN_NR_MOLPCBA}_OGBG-MolHIV.pkl"

with open(pkl_filename, 'rb') as pkl_input:
    graph = pkl.load(pkl_input)

edge_index = graph['edge_index']
num_nodes = graph['num_nodes']

G = nx.Graph()

G.add_nodes_from(range(num_nodes))

edges = list(zip(edge_index[0], edge_index[1]))
G.add_edges_from(edges)

# Create a lookup (LU) table
Create a lookup table whose numerical columns are the node attributes, plus a binary column `target` holding the target variable. Index the lookup table by the node names used in the graph.

In [3]:
# Lookup table: one row per node (RangeIndex 0..n-1), one column per node
# feature -- presumably node_feat is a (num_nodes, num_features) array; confirm.
attributes = graph['node_feat']
lu = pd.DataFrame(attributes)

# Binary target: feature column TARGET_FEATURE thresholded at TARGET_THRESHOLD.
# NOTE(review): the meaning of feature 0 and of the cut-off 6 is not documented here -- confirm.
TARGET_FEATURE = 0
TARGET_THRESHOLD = 6
lu['target'] = lu[TARGET_FEATURE] >= TARGET_THRESHOLD

# The lookup table must be indexed by the names of the nodes in G.
assert set(lu.index) == set(G.nodes), "lookup-table index does not match the nodes of G"

In [4]:
# Run the subgroup search on the node attributes and target only. The
# plotting columns that later cells add to `lu` ('color', 'size') are dropped,
# so re-running this cell after them gives the same input as a fresh run.
lu_attributes = lu.drop(columns=['color', 'size'], errors='ignore')
result = find_groups(G, 20, lu_attributes, ablation_mode=False, use_multiprocessing=False)

100%|██████████| 266/266 [00:58<00:00,  4.56it/s]


In [5]:
# Show the subgroups found by find_groups: one row per prototype node with its
# rho/sigma group sizes, score q, full ranking, reference group and subgroup.
result

Unnamed: 0_level_0,rho,sigma,q,ranks,reference,subgroup
node,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
89,11,5,0.232905,"[(89, True), (87, False), (90, False), (94, Fa...","[89, 87, 90, 94, 85, 88, 91, 84, 96, 95, 93]","[89, 87, 90, 94, 85]"
56,16,10,0.207524,"[(56, True), (57, True), (55, False), (58, Fal...","[56, 57, 55, 58, 54, 96, 52, 59, 67, 94, 53, 5...","[56, 57, 55, 58, 54, 96, 52, 59, 67, 94]"
113,11,7,0.207201,"[(113, True), (111, False), (114, False), (117...","[113, 111, 114, 117, 112, 115, 106, 105, 119, ...","[113, 111, 114, 117, 112, 115, 106]"
199,11,7,0.207201,"[(199, True), (197, False), (200, False), (201...","[199, 197, 200, 201, 198, 254, 196, 203, 192, ...","[199, 197, 200, 201, 198, 254, 196]"
219,11,7,0.207201,"[(219, True), (217, False), (220, False), (223...","[219, 217, 220, 223, 218, 216, 221, 215, 225, ...","[219, 217, 220, 223, 218, 216, 221]"
167,21,12,0.206981,"[(167, False), (169, True), (168, True), (162,...","[167, 169, 168, 162, 161, 163, 170, 159, 164, ...","[167, 169, 168, 162, 161, 163, 170, 159, 164, ..."
55,5,3,0.206559,"[(55, False), (54, False), (56, True), (96, Tr...","[55, 54, 56, 96, 57]","[55, 54, 56]"
29,5,3,0.206559,"[(29, True), (30, False), (28, False), (31, Tr...","[29, 30, 28, 31, 32]","[29, 30, 28]"
140,9,4,0.203704,"[(140, True), (138, False), (141, False), (142...","[140, 138, 141, 142, 139, 137, 136, 144, 143]","[140, 138, 141, 142]"
96,15,5,0.19245,"[(96, True), (94, False), (54, False), (52, Fa...","[96, 94, 54, 52, 55, 95, 90, 89, 51, 53, 56, 9...","[96, 94, 54, 52, 55]"


In [6]:
# Summary table of the discovered subgroups, one row per prototype node.
# Built with rename/rename_axis, which return new objects. Assigning
# `output.index = result.index` would share the same Index object, and then
# clearing its name would also clear `result.index.name`.
output = (
    result[['rho', 'sigma', 'q', 'ranks']]
    .rename(columns={'rho': 'Rho', 'sigma': 'Sigma', 'q': 'Q', 'ranks': 'Ranks'})
    .rename_axis(None)
)
output.insert(0, 'Prototype', result.index.to_numpy())

# LaTeX rendering of the first five rows.
output.head(5).to_latex()

'\\begin{tabular}{lrrrrl}\n\\toprule\n & Prototype & Rho & Sigma & Q & Ranks \\\\\n\\midrule\n89 & 89 & 11 & 5 & 0.232905 & [(89, True), (87, False), (90, False), (94, False), (85, False), (88, True), (91, False), (84, True), (96, True), (95, True), (93, True), (86, False), (92, False), (82, False), (54, False), (52, False), (55, False), (83, True), (78, False), (77, True), (51, True), (53, True), (56, True), (79, False), (75, False), (81, True), (57, True), (80, False), (50, False), (48, False), (58, False), (76, True), (70, False), (97, False), (69, True), (47, True), (49, True), (59, False), (71, False), (98, False), (99, False), (60, True), (67, False), (72, False), (46, False), (73, True), (44, False), (61, False), (68, True), (74, True), (100, False), (43, True), (45, True), (62, True), (102, True), (101, False), (63, False), (64, True), (36, False), (65, False), (34, False), (66, True), (37, False), (33, True), (38, False), (35, True), (42, True), (39, False), (25, False), (41, 

In [7]:
# Persist the full summary table (ranks included) as CSV, named after the graph id.
csv_filename = f'subgroups_graph_{PROTEIN_NR_MOLPCBA}_OGBG-MolHIV.csv'
output.to_csv(path_or_buf=csv_filename)

In [None]:
# Target-based colouring: rows whose target is True become red, the rest sky blue.
is_target = lu['target']
lu['color'] = 'Skyblue'
lu.loc[is_target, 'color'] = 'Red'

In [None]:
# Highlights one prototype's groups in the lookup table's 'color' column;
# adjust the arguments/colours to change what is made visible.
def change_value(rho_val, prototype, sigma_val=None, results=None, lookup=None,
                 reference_color='Green', subgroup_color='Orange'):
    """Colour the reference group (and optionally the subgroup) of ``prototype``.

    Nodes are taken in rank order from ``results.loc[prototype, 'ranks']``, a
    sequence whose entries hold the node id at position 0.

    Parameters
    ----------
    rho_val : int-like
        Number of top-ranked nodes forming the reference group; they are
        coloured ``reference_color``.
    prototype : hashable
        Row label of ``results`` (a prototype node id).
    sigma_val : int-like, optional
        If given, the first ``sigma_val`` ranked nodes (the subgroup) are
        coloured ``subgroup_color``, overriding the reference colour.
    results : pandas.DataFrame, optional
        Table with a ``'ranks'`` column; defaults to the notebook's ``result``.
    lookup : pandas.DataFrame, optional
        Table whose ``'color'`` column is updated in place; defaults to the
        notebook's ``lu``.
    reference_color, subgroup_color : str
        Colours used for the reference group and the subgroup.

    Returns
    -------
    None
        ``lookup`` is modified in place.
    """
    # Fall back to the notebook-level tables so existing calls keep working.
    if results is None:
        results = result
    if lookup is None:
        lookup = lu

    ranks = results.loc[prototype, 'ranks']

    reference_group = [x[0] for x in ranks[0:int(rho_val)]]
    lookup.loc[reference_group, 'color'] = reference_color

    # The subgroup is a prefix of the ranking, so it is painted after the reference group.
    if sigma_val is not None:
        subgroup = [x[0] for x in ranks[0:int(sigma_val)]]
        lookup.loc[subgroup, 'color'] = subgroup_color
    return None

In [None]:
# Prototype node whose reference group is highlighted; must be a row of `result`.
proto = 46
# Fail early with a useful message (instead of a bare KeyError after `lu` was
# already modified) when the chosen prototype is not among the results.
if proto not in output.index:
    raise KeyError(
        f"Prototype {proto} is not in the results; "
        f"available prototypes include {output.index.tolist()[:20]}"
    )

# Reset the plot styling before highlighting.
lu['color'] = 'gray'
lu['size'] = 500

rho = output.loc[proto, 'Rho']
sigma = output.loc[proto, 'Sigma']  # subgroup size; not used by change_value(rho, proto)
change_value(rho, proto)

In [None]:
# Draw G using the colours and sizes prepared in `lu`.
# Per-node sequences passed to nx.draw are matched to list(G), so they are
# looked up in that exact order rather than relying on `lu`'s row order.
# The spring layout is seeded so the picture is reproducible across runs.
LAYOUT_SEED = 42
node_order = list(G)

plt.figure(figsize=(10, 10))
nx.draw(
    G,
    pos=nx.spring_layout(G, seed=LAYOUT_SEED),
    with_labels=True,
    node_color=lu.loc[node_order, 'color'].tolist(),
    edge_color='gray',
    node_size=lu.loc[node_order, 'size'].tolist(),
    font_size=10,
)
plt.title("Network Graph")
plt.show()

In [None]:
# Ablation run with the same inputs and settings as the normal run:
# the plotting-only columns added to `lu` ('color', 'size') are dropped so the
# lookup table again holds only attribute columns plus 'target', and
# multiprocessing is disabled exactly as in the normal run.
lu_attributes = lu.drop(columns=['color', 'size'], errors='ignore')
result_ablation = find_groups(G, 20, lu_attributes, ablation_mode=True, use_multiprocessing=False)
result_ablation

In [None]:
from ablation_metrics import compare_subgroups

# Compare the subgroups of the normal run with those of the ablation run.
# NOTE(review): k=10 presumably limits the comparison to the 10 best-ranked
# prototypes of each run -- confirm in ablation_metrics.compare_subgroups.
comparison = compare_subgroups(result, result_ablation, k=10)
comparison

In [None]:
from ablation_metrics import evaluate_pattern_subgroups

# Score both runs the same way (k=10). Each call returns a pair, presumably
# (mean ratio, distribution of per-subgroup ratios) -- confirm in ablation_metrics.
scores = [evaluate_pattern_subgroups(run, lu, k=10) for run in (result, result_ablation)]
(mean_ratio_normal, dist_ratios_normal), (mean_ratio_ablation, dist_ratios_ablation) = scores

print("Avg fraction of target=1 in top-10 normal subgroups: ", mean_ratio_normal)
print("Avg fraction of target=1 in top-10 ablation subgroups:", mean_ratio_ablation)

In [None]:
# Overlay the q-score distributions of the normal and ablation runs.
# Both histograms share the same 30 bin edges over the combined value range;
# with independent `bins=30`, each run would get its own edges and the bars
# would not be comparable.
q_normal = result['q']
q_ablation = result_ablation['q']
q_all = pd.concat([q_normal, q_ablation]).dropna()
q_range = (q_all.min(), q_all.max()) if not q_all.empty else None

fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(q_normal, bins=30, range=q_range, alpha=0.5, label='Normal')
ax.hist(q_ablation, bins=30, range=q_range, alpha=0.5, label='Ablation')
ax.set_xlabel('Q Score')
ax.set_ylabel('Frequency')
ax.set_title('Distribution of Q in Normal vs. Ablation')
ax.legend()
plt.show()