##  1-dimensional Weisfeiler-Lehman algorithm implementation

### Structure:

Input - graph G 

Output - stable colouring of the graph nodes


#### Variables:
- G: graph
- C[v]: list of colours of nodes, initially all set to same colour to 1.
- M: multiset, initially empty
- L: work list, this consists of the new colour classes that are introduced in the latest iteration (initially set to initial colour).
- V: visited list to keep check of nodes when finding nodes with colour 


#### Functions:
- FindXColourNodes: 
    - To find the nodes in the graph with colours X
    - Outputs a list of vertices with colour X.
- FindNeighbourhood: 
    - To find the neighbourhood of the nodes with colour X
    - Outputs a list of vertices which are the neighbours of a particular node with colour X.
- UpdateMultiset: 
    - Updates the tupes from multiple elements for a particular node to a single element for a single node
- RadixSort: 
    - Implement radix sort
- UpdateColours: 
    - Updates the colours of the nodes and adds the new ones in L and set M to NULL

###### -NOTE: Visited is not required i think

#### Steps of the Algorithm:
- Initialise the graph G
- Initialise the colouring list C[v] for each node v. Initially to 1
- Initialise the multiset M to NULL
- Initialise the work list L to initial colouring
- Initialise the visited list V to NULL
- While L is not empty
- Pop the first element from L
- For the colour class c popped, find all the vertices of that colour
- For each vertex w in the colour class, find the neighbourhood of w, if not already visited
- For each vertex v in the neighbourhood of w, add v to the multiset M. Add neighbour node and current colour class, i.e. add(v, c)
- If M is not empty, Sort the multiset M using radix sort
- For each element (v, c) in the sorted multiset M
- Replace the tuples in M from (v, c1), ... (v, cr) to a single tuple (C(v), c1, ...cr, v)
- Apply radix sort on M
- For each element (C(v), c1, ...cr, v) in the sorted multiset, update the colours of the nodes. Leave the largest part to c(prev colour) and update the other to new colours. 
- Update the work list L with the new colour classes
- Make visited list NULL
- Make multiset NULL


In [369]:
%pip install networkx
%pip install multiset


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [404]:
from multiset import *
import networkx as nx
from collections import Counter


In [1086]:
class WL:
    def __init__(self, G, L, M):
        self.G = G
        # self.C = C
        self.M = M
        self.L = L

    def FindXColourNodes(self, G, reqdColour):
        coloursList = []
        print("nodes: ", G.nodes)
        color_attributes = nx.get_node_attributes(G, 'colour')

        for i in G.nodes:
            if int(color_attributes[i]) == reqdColour:
                coloursList.append(i)
                
        return coloursList

    #implement this without using a library 
    def FindNeighbourhood(self, G, node):
        
        neighboutList = list(self.G.neighbors(node))
        # print("neighbourhood: ", i, list(self.G.neighbors(i)))

        return neighboutList
            


    def RadixSort(self, M, keys):
        #sort the multiset which looks something like this: Counter({(3, 1): 1, (1, 1): 1, (2, 1): 1, (4, 1): 1})
        #number of tuples in the list keys
        n = len(keys)
        if n == 0:
            return []

        numberOfVal = len(keys[0]) #values in each tuple in the multiset
        #itr from numberOfVal-1 to 0, we can access each value in each tuple
        #e.g. in ({(3, 1): 1, (1, 1): 1, (2, 1): 1, (4, 1): 1}), say keys[0] = (3, 1) then keys[0][0] = 3 and keys[0][1] = 1
        #so when sorting numberOfVal-1 will be the first place to be sorted

        #to temporarily store the tuples to sort
        bins = [[] for _ in range(10)]

        for i in range(numberOfVal-1, -1, -1):
            for j in range(n):
                idx = keys[j][i]
                bins[idx].append(keys[j])

            self.M.clear()
            for j in range(10):
                for itr in bins[j]:
                    tempTuple = []
                    tempTuple.append(itr)
                    self.M.update(tempTuple)
                bins[j].clear()


    # Helper function to print the Counter without the prefix
    def print_multiset(self, multiset):
        formatted_output = ', '.join([f"{k}: {v}" for k, v in multiset.items()])
        print("Multiset: ", formatted_output)


    def OneWL(self):
        print(self.L, self.M)
        #iterates on the colour classes
        while self.L: #loop 1
            
            for i in self.L: #loop 2
                #get the colour class at the front
                colour = self.L.pop(0)

                #loop to get all the nodes from the graph with colour c
                print("nodes: ", self.G.nodes)
                color_attributes = nx.get_node_attributes(self.G, 'colour') #return a dict with key as nodes and value as their colour
                print("colour attri: ", color_attributes) 

                for i in self.G.nodes: #loop 3
                    if int(color_attributes[i]) == colour:
                        # XColourNodes.append(i)
                        
                        #loop to find the neighbours of the node with colour c
                        neighboutList = list(self.G.neighbors(i)) #inplace of loop 4
                        print("neighbourhood: ", i, neighboutList)
                        for j in neighboutList:
                            tempTuple = []
                            tempTuple.append((j, colour))
                            # print("inner_tempTuple: ", tempTuple)
                            self.M[tempTuple[0]] = 1
                            # if tempTuple[0] not in self.M: #this and the abode line both to check and only insert a single copy of the tuple in the counter
                            #     self.M.update(tempTuple)
                            # print("M: ", self.M)

                        
                        print("tempTuple: ", tempTuple)
                        # self.print_multiset(self.M) # to print the multiset with counter keyword before it
                        print("multiset: ", self.M)



            keys = list(self.M.keys())
            print("some val", [keys[0]])
            #first radix sort
            self.RadixSort(self.M, keys)
            print("new M: ", self.M)
            keys = list(self.M.keys())

            #to update multi-tuple multiset to for a single-tuple multiset 
            print("updated keys: ", keys)
            curNode = keys[0][0]

            colourList = ()
            nodeColour = -1
            print("curNode: ", curNode) 


            j=0
            # self.M.clear()
            newM = Counter()

            for i in range(len(keys)):
                nodeColour = int(self.G.nodes[curNode]['colour'])
                colourList = colourList + (nodeColour, )
                while j<len(keys):
                    if(keys[j][0] == curNode):
                        node = (keys[j][1], )  
                        colourList = colourList + node
                        print("keys[j]: ", keys[j])
                    else:
                        colourList = colourList + (curNode, )
                        curNode = keys[j][0]
                        print("curnode: ", curNode, "node[0]", keys[j][0])
                        break
                    j+=1
                if(j == len(keys)):
                        colourList = colourList + (curNode, )
                        
                print(colourList)
                #now here insert the newly created tuple to the mnultiset 
                #getting error because of some str as list indices
                tempList = []
                tempList.append(colourList)
                newM.update(tempList)
                colourList = tuple()
                if(j==len(keys)):
                    break

            #can't do this to assign the new tuple to the old one or can't even make such changes in the old one as the idx values become problamatic. Try by removing the comment and running
            # self.M = newM

            print("nodes_attri: ", keys[0][0])
            print("node_col: ", type(colourList))
            print("M: ", newM)

        return 



In [1087]:
G = nx.Graph()
G.add_node(1, colour='1')
G.add_node(2, colour='1')
G.add_node(3, colour='7')
G.add_node(4, colour='1')

G.add_edge(1, 2)
G.add_edge(2, 3)
G.add_edge(3, 4)
G.add_edge(4, 1)
G.add_edge(1, 3)


# colours = [1, 1, 1, 1]
M = Counter()
workList = [1, 7]



In [1088]:
# print(list(G.nodes[1][colour]))
print(list(G.nodes[1]['colour']))  # Output: 'red'
print(list(G.edges))
color_attributes = nx.get_node_attributes(G, 'colour')
print(color_attributes) 

['1']
[(1, 2), (1, 4), (1, 3), (2, 3), (3, 4)]
{1: '1', 2: '1', 3: '7', 4: '1'}


In [1089]:
obj = WL(G, workList, M)
obj.OneWL()

[1, 7] Counter()
nodes:  [1, 2, 3, 4]
colour attri:  {1: '1', 2: '1', 3: '7', 4: '1'}
neighbourhood:  1 [2, 4, 3]
tempTuple:  [(3, 1)]
multiset:  Counter({(2, 1): 1, (4, 1): 1, (3, 1): 1})
neighbourhood:  2 [1, 3]
tempTuple:  [(3, 1)]
multiset:  Counter({(2, 1): 1, (4, 1): 1, (3, 1): 1, (1, 1): 1})
neighbourhood:  4 [3, 1]
tempTuple:  [(1, 1)]
multiset:  Counter({(2, 1): 1, (4, 1): 1, (3, 1): 1, (1, 1): 1})
some val [(2, 1)]
new M:  Counter({(1, 1): 1, (2, 1): 1, (3, 1): 1, (4, 1): 1})
updated keys:  [(1, 1), (2, 1), (3, 1), (4, 1)]
curNode:  1
keys[j]:  (1, 1)
curnode:  2 node[0] 2
(1, 1, 1)
keys[j]:  (2, 1)
curnode:  3 node[0] 3
(1, 1, 2)
keys[j]:  (3, 1)
curnode:  4 node[0] 4
(7, 1, 3)
keys[j]:  (4, 1)
(1, 1, 4)
nodes_attri:  1
node_col:  <class 'tuple'>
M:  Counter({(1, 1, 1): 1, (1, 1, 2): 1, (7, 1, 3): 1, (1, 1, 4): 1})
nodes:  [1, 2, 3, 4]
colour attri:  {1: '1', 2: '1', 3: '7', 4: '1'}
neighbourhood:  3 [2, 4, 1]
tempTuple:  [(1, 7)]
multiset:  Counter({(1, 1, 1): 1, (1, 1, 2):

IndexError: tuple index out of range