<a href="https://colab.research.google.com/github/aaririri/task8_KrackhardtKite/blob/main/task8_A.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch-geometric



In [None]:
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import to_networkx
import networkx as nx
import collections

def weisfeiler_lehman_hash(graph):
    colors = {node: data.get('label', 0) for node, data in graph.nodes(data=True)}

    for _ in range(len(graph.nodes())): # Iterate enough times for convergence
        new_colors = {}
        for node in graph.nodes():
            neighbor_colors = sorted([colors[nbr] for nbr in graph.neighbors(node)])
            signature = (colors[node], tuple(neighbor_colors))
            new_colors[node] = hash(signature) # Hash the signature to get a new color

        if new_colors == colors:
            break
        colors = new_colors

    canonical_hash = str(sorted(colors.values()))
    return canonical_hash

def find_isomorphic_groups_from_pyg(dataset):
    hashes = collections.defaultdict(list)

    for i, data in enumerate(dataset):
        g = to_networkx(data, node_attrs=['x'])

        nx.set_node_attributes(g, {j: int(x[0]) for j, x in enumerate(data.x)}, 'label')

        if len(g) > 0:
            h = weisfeiler_lehman_hash(g)
            hashes[h].append(i)

    isomorphic_groups = {h: indices for h, indices in hashes.items() if len(indices) > 1}
    return isomorphic_groups


In [22]:
import random
import copy

def perturb_graph_add_edge_or_remove(graph, graph_id=None, copy_id=None):
    g = copy.deepcopy(graph)
    nodes = list(g.nodes())
    if random.choice(["add", "remove"]) == "add":
        possible_edges = [(u, v) for i, u in enumerate(nodes) for v in nodes[i+1:]
                          if not g.has_edge(u, v)]
        if not possible_edges:
            if graph_id is not None:
                print(f"Graph {graph_id} Copy {copy_id}: No edge could be added (graph already complete).")
            return g
        edge_to_add = random.choice(possible_edges)
        g.add_edge(*edge_to_add)
        if graph_id is not None:
            print(f"Graph {graph_id} Copy {copy_id}: Added edge {edge_to_add}")
    else:
        if g.number_of_edges() == 0:
            if graph_id is not None:
                print(f"Graph {graph_id} Copy {copy_id}: No edge could be removed (graph has no edges).")
            return g
        edge_to_remove = random.choice(list(g.edges()))
        g.remove_edge(*edge_to_remove)
        if graph_id is not None:
            print(f"Graph {graph_id} Copy {copy_id}: Removed edge {edge_to_remove}")
    return g


def generate_perturbed_collection_from_nx(nx_graphs, largest_group, num_copies=10):
    all_graphs = []
    for idx, g in zip(largest_group, nx_graphs):
        all_graphs.append((f"orig-{idx}", g))
        for copy_id in range(1, num_copies + 1):
            pert_g = perturb_graph_add_edge_or_remove(g, graph_id=idx, copy_id=copy_id)
            all_graphs.append((f"pert-{idx}-{copy_id}", pert_g))
    print(f"\nGenerated {len(all_graphs)} graphs "
          f"({len(nx_graphs)} originals + {len(nx_graphs) * num_copies} perturbed).")
    return all_graphs


def run_wl_on_collection(graphs):
    hashes = collections.defaultdict(list)
    for graph_id, g in graphs:
        if len(g) > 0:
            h = weisfeiler_lehman_hash(g)
            hashes[h].append(graph_id)
    isomorphic_groups = {h: ids for h, ids in hashes.items() if len(ids) > 1}
    print("\nWL Test Results After Perturbation:")
    print(f"  - Total groups found: {len(isomorphic_groups)}")
    total_graphs = sum(len(ids) for ids in isomorphic_groups.values())
    print(f"  - Total graphs inside groups: {total_graphs}")
    if isomorphic_groups:
        print("\n  Group details:")
        for i, ids in enumerate(isomorphic_groups.values(), 1):
            print(f"    Group {i}: {ids}")
    return isomorphic_groups


In [None]:

if __name__ == "__main__":
    print("1. Loading AIDS dataset...")
    dataset = TUDataset(root='/tmp/AIDS', name='AIDS', use_node_attr=True)

    print("2. Finding isomorphic groups...")
    isomorphic_groups = find_isomorphic_groups_from_pyg(dataset)

    if not isomorphic_groups:
        print("   No isomorphic groups found.")
    else:
        num_isomorphic_groups = len(isomorphic_groups)
        total_graphs_in_groups = sum(len(indices) for indices in isomorphic_groups.values())
        print(f"   - Total Isomorphic Groups Found: {num_isomorphic_groups}")
        print(f"   - Total Graphs in Isomorphic Groups: {total_graphs_in_groups}")

        print("\n   Isomorphic Group Details (Graph Indices):")
        sorted_groups = sorted(isomorphic_groups.values(), key=lambda x: x[0])
        for i, group in enumerate(sorted_groups):
            print(f"     - Group {i+1}: {group}")


1. Loading AIDS dataset...
2. Finding isomorphic groups...
   - Total Isomorphic Groups Found: 84
   - Total Graphs in Isomorphic Groups: 182

   Isomorphic Group Details (Graph Indices):
     - Group 1: [18, 1832]
     - Group 2: [23, 1950]
     - Group 3: [26, 1184]
     - Group 4: [27, 666]
     - Group 5: [48, 877]
     - Group 6: [60, 1685]
     - Group 7: [63, 751]
     - Group 8: [78, 123]
     - Group 9: [90, 711]
     - Group 10: [120, 1825]
     - Group 11: [125, 1505]
     - Group 12: [127, 700]
     - Group 13: [138, 234]
     - Group 14: [142, 1909]
     - Group 15: [153, 1617]
     - Group 16: [160, 213]
     - Group 17: [187, 603]
     - Group 18: [189, 1004, 1711]
     - Group 19: [197, 951]
     - Group 20: [218, 239]
     - Group 21: [219, 1071]
     - Group 22: [230, 574]
     - Group 23: [232, 878]
     - Group 24: [235, 812]
     - Group 25: [244, 1196]
     - Group 26: [257, 856]
     - Group 27: [262, 509]
     - Group 28: [277, 373]
     - Group 29: [280, 976]
 

In [26]:
largest_group = max(isomorphic_groups.values(), key=len)
print(f"Largest isomorphic group has {len(largest_group)} graphs: {largest_group}")

nx_graphs = []
for idx in largest_group:
    data = dataset[idx]
    g = to_networkx(data, node_attrs=['x'])
    nx.set_node_attributes(g, {j: int(x[0]) for j, x in enumerate(data.x)}, 'label')
    nx_graphs.append(g)

print("\nGenerating perturbed copies...")
all_graphs = generate_perturbed_collection_from_nx(nx_graphs, largest_group, num_copies=10)
print(f"Generated {len(all_graphs)} graphs in total "
      f"({len(nx_graphs)} originals + {len(nx_graphs) * 10} perturbed).")

perturbed_groups = run_wl_on_collection(all_graphs)



Largest isomorphic group has 8 graphs: [709, 1270, 1454, 1585, 1607, 1809, 1964, 1977]

Generating perturbed copies...
Graph 709 Copy 1: Removed edge (18, 23)
Graph 709 Copy 2: Removed edge (13, 12)
Graph 709 Copy 3: Removed edge (3, 25)
Graph 709 Copy 4: Removed edge (20, 22)
Graph 709 Copy 5: Added edge (11, 14)
Graph 709 Copy 6: Added edge (16, 18)
Graph 709 Copy 7: Added edge (0, 2)
Graph 709 Copy 8: Added edge (13, 15)
Graph 709 Copy 9: Added edge (8, 16)
Graph 709 Copy 10: Removed edge (3, 25)
Graph 1270 Copy 1: Removed edge (3, 2)
Graph 1270 Copy 2: Removed edge (0, 9)
Graph 1270 Copy 3: Added edge (21, 25)
Graph 1270 Copy 4: Removed edge (5, 6)
Graph 1270 Copy 5: Removed edge (17, 16)
Graph 1270 Copy 6: Removed edge (9, 8)
Graph 1270 Copy 7: Removed edge (6, 7)
Graph 1270 Copy 8: Removed edge (0, 9)
Graph 1270 Copy 9: Removed edge (8, 7)
Graph 1270 Copy 10: Removed edge (12, 11)
Graph 1454 Copy 1: Removed edge (5, 4)
Graph 1454 Copy 2: Removed edge (26, 25)
Graph 1454 Copy 3: R

In [27]:
print("\nChecking original–perturbed isomorphism preservation...")

group_lookup = {}
for group in perturbed_groups.values():
    for gid in group:
        group_lookup[gid] = group

broken_count = 0
total_originals = len(largest_group)

for orig_idx in largest_group:
    orig_id = f"orig-{orig_idx}"
    pert_ids = [f"pert-{orig_idx}-{copy_id}" for copy_id in range(1, 11)]

    group = group_lookup.get(orig_id, None)

    if group is None:
        broken_count += 1
        print(f"  Original {orig_id}: no longer in any isomorphic group.")
    else:
        if not all(pid in group for pid in pert_ids):
            broken_count += 1
            missing = [pid for pid in pert_ids if pid not in group]
            print(f"  Original {orig_id}: broke isomorphism with {len(missing)} perturbed copies.")

print(f"\nNumber of original–perturbed relationships broken: {broken_count} out of {total_originals} originals.")



Checking original–perturbed isomorphism preservation...
  Original orig-709: broke isomorphism with 10 perturbed copies.
  Original orig-1270: broke isomorphism with 10 perturbed copies.
  Original orig-1454: broke isomorphism with 10 perturbed copies.
  Original orig-1585: broke isomorphism with 10 perturbed copies.
  Original orig-1607: broke isomorphism with 10 perturbed copies.
  Original orig-1809: broke isomorphism with 10 perturbed copies.
  Original orig-1964: broke isomorphism with 10 perturbed copies.
  Original orig-1977: broke isomorphism with 10 perturbed copies.

Number of original–perturbed relationships broken: 8 out of 8 originals.


In [28]:
print(f"\nSummary:")
print(f"  - Broken original–perturbed relationships: {broken_count} out of {total_originals}")
print(f"  - Preserved (still isomorphic): {preserved_count} out of {total_originals}")



Summary:
  - Broken original–perturbed relationships: 8 out of 8
  - Preserved (still isomorphic): 0 out of 8


#All isomorphic relationships break because WL detects even a single edge addition/removal as a structural change. The new graph has a different neighborhood coloring sequence, leading to a different WL hash. Hence, none of the perturbed graphs stay in the same isomorphic group as the originals.