## 1-dimensional Weisfeiler-Lehman algorithm implementation


### Structure:

Input - graph G

Output - stable colouring of the graph nodes

#### Variables:

-   G: graph
-   C[v]: list of colours of nodes, initially all set to same colour to 1.
-   M: multiset, initially empty
-   L: work list, this consists of the new colour classes that are introduced in the latest iteration (initially set to initial colour).

#### Functions:

-   FindNeighbourhood:
    -   To find the neighbourhood of the nodes with colour X
    -   Outputs a list of vertices which are the neighbours of a particular node with colour X.
-   UpdateMultiset:
    -   Updates the tupes from multiple elements for a particular node to a single element for a single node
-   RadixSort:
    -   Implement radix sort
-   UpdateColours:
    -   Updates the colours of the nodes and adds the new ones in L and set M to NULL

###### -NOTE: Visited is not required i think

#### Steps of the Algorithm:

-   Initialise the graph G
-   Initialise the colouring list C[v] for each node v. Initially to 1
-   Initialise the multiset M to NULL
-   Initialise the work list L to initial colouring
-   Initialise the visited list V to NULL
-   While L is not empty
-   Pop the first element from L
-   For the colour class c popped, find all the vertices of that colour
-   For each vertex w in the colour class, find the neighbourhood of w, if not already visited
-   For each vertex v in the neighbourhood of w, add v to the multiset M. Add neighbour node and current colour class, i.e. add(v, c)
-   If M is not empty, Sort the multiset M using radix sort
-   For each element (v, c) in the sorted multiset M
-   Replace the tuples in M from (v, c1), ... (v, cr) to a single tuple (C(v), c1, ...cr, v)
-   Apply radix sort on M
-   For each element (C(v), c1, ...cr, v) in the sorted multiset, update the colours of the nodes. Leave the largest part to c(prev colour) and update the other to new colours.
-   Update the work list L with the new colour classes
-   Make multiset NULL


In [1]:
%pip install networkx
%pip install multiset


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
from multiset import *
import networkx as nx
from collections import Counter


In [24]:
class WL:
    def __init__(self, G, L, M):
        self.G = G
        self.M = M
        self.L = L
            

    def StuffTuples(self, M):
        maxLen=-1

        for i in M:
            maxLen = max(maxLen, len(i))
            # print("len(i): ", len(i), "maxLen: ", maxLen)

        M_padded = []
        for i in M:
            if len(i) < maxLen:
                padded_list = list(i) + [0] * (maxLen - len(i))
                M_padded.append(tuple(padded_list))
            else:
                M_padded.append(i)
        
        self.M = M_padded.copy()
        # print(self.M)


    def RadixSort(self, M):
        #sort the multiset which looks something like this: Counter({(3, 1): 1, (1, 1): 1, (2, 1): 1, (4, 1): 1})
        #number of tuples in the list keys
        n = len(self.M)
        if n == 0:
            return []
        
        #function call for stuffing if required
        self.StuffTuples(M)
        print("Stuffed M: ", self.M)

        numberOfVal = len(self.M[0]) #values in each tuple in the multiset
        #itr from numberOfVal-1 to 0, we can access each value in each tuple. e.g. in ({(3, 1): 1, (1, 1): 1, (2, 1): 1, (4, 1): 1}), say keys[0] = (3, 1) then keys[0][0] = 3 and keys[0][1] = 1. so when sorting numberOfVal-1 will be the first place to be sorted

        #to temporarily store the tuples to sort
        bins = [[] for _ in range(100)]

        for i in range(numberOfVal-1, -1, -1):
            for j in range(n):
                
                print("Idx: ", "M[j][i]: ", self.M[j][i], "i: ", i, "j: ", j)
                idx = self.M[j][i]
                bins[idx].append(self.M[j])
                # print(bins)

            self.M.clear()
            for j in range(10):
                for itr in bins[j]:
                    tempTuple = (itr)
                    self.M.append(tempTuple)
                    # var = tempTuple[0]
                    # print("vals[tempTuple]: ", vals[var], tempTuple[0])
                    # self.M[tempTuple[0]] = vals[tempTuple[0]]
                bins[j].clear()

        return self.M


    def PreprocessTuples2(self):
        tupleLen = len(self.M)
        newList = []
        for j in range(len(self.M)):
            tempTuple = ()
            for i in self.M[j]:
                if i != 0:
                    tempTuple = tempTuple + (i, )
            # print("templist: ", tempTuple)
            newList.append(tempTuple)

        self.M = newList
        # print("processed: ", self.M)


    def OneWL(self):
        #iterates on the colour classes
        newColour = self.L[-1]
        colourChangeCheck = [] #stores the nodes whose colours were changed in the previous iterartion
        print("newColour: ", self.L)
        while self.L: #loop 1
            print("----------------START of WorkList----------------")
            print(" L and M: ", self.L, self.M)

            # test=0
            LSize = len(self.L)
            for i in range(LSize): #loop 2
                # print("i: ", i, test, len(self.L))
                # test = 1
                #get the colour class at the front
                print("L size before pop: ", LSize)
                colour = self.L.pop(0)

                #loop to get all the nodes from the graph with colour c
                print("nodes: ", self.G.nodes)
                color_attributes = nx.get_node_attributes(self.G, 'colour') #return a dict with key as nodes and value as their colour
                print("colour attri: ", color_attributes) 

                for i in self.G.nodes: #loop 3
                    if int(color_attributes[i]) == colour:
                        # XColourNodes.append(i)
                        
                        #loop to find the neighbours of the node with colour c
                        neighboutList = list(self.G.neighbors(i)) #inplace of loop 4
                        # print("neighbourhood: ", i, neighboutList)
                        for j in neighboutList:
                            tempTuple = (j, colour)
                            # print("inner_tempTuple: ", tempTuple)
                            # self.M[tempTuple[0]] = 1
                            # if tempTuple[0] not in self.M: #this and the abode line both to check and only insert a single copy of the tuple in the counter
                            self.M.append(tempTuple)
                        # print("M: ", self.M)

                print("After adding to multiset: ", self.M)


            #first radix sort
            print("------------------First Radix sort started-------------------")
            print("M before sorting: ", self.M)
            self.RadixSort(self.M)        
            print("M after sorting: ", self.M)
            print("------------------First Radix sort done-------------------")



            # TupleTransformation(self.M)
            #to update multi-tuple multiset to for a single-tuple multiset 
            keys = self.M.copy()
            # print("M -> Keys: ", keys)
            # print("M: ", self.M)
            curNode = self.M[0][0]

            colourList = ()
            nodeColour = -1
            # print("curNode: ", curNode) 


            j=0
            self.M.clear()
            for i in range(len(keys)):
                nodeColour = int(self.G.nodes[curNode]['colour'])
                colourList = colourList + (nodeColour, )
                while j<len(keys):
                    if(keys[j][0] == curNode):
                        node = (keys[j][1], )  
                        colourList = colourList + node
                        # print("keys[j]: ", keys[j])
                    else:
                        colourList = colourList + (curNode, )
                        curNode = keys[j][0]
                        # print("curnode: ", curNode, "node[0]", keys[j][0])
                        break
                    j+=1
                if(j == len(keys)):
                        colourList = colourList + (curNode, )
                        
                # print("colourlist: ", colourList)
                #now here insert the newly created tuple to the multiset 
                #getting error because of some str as list indices
                self.M.append(colourList)
                # print("Updated M: ", self.M)
                colourList = tuple()
                if(j==len(keys)):
                    break

            print("M after transformation: ", self.M)
            #can't do this to assign the new tuple to the old one or can't even make such changes in the old one as the idx values become problamatic. Try by removing the comment and running

            print("------------------Second Radix sort started-------------------")
            print("M before sorting: ", self.M)
            #How to apply radix sort on this because this sorting is done on the basis of digit place. But here it is possible that the tuple might have unequal number of elements in it and even if we consider 0 like we do in integer values(normally) then the resultant will change. So, we can't do that over here. What to do? ask sir!
            #second radix sort
            self.RadixSort(self.M)
            print("M after sorting: ", self.M)
            print("------------------Second Radix sort done-------------------")

            print("before processing: ", self.M)
            self.PreprocessTuples2()
            print("post-processing: ", self.M)

            tupleCount = {}
            maxCount=-1
            for i in self.M:
                node = i[-1]
                subTuple = i[:-1]
                count=0
                # print(i[:-1], node)
                if subTuple in tupleCount:
                    count = tupleCount[subTuple]['count'] + 1
                    tupleCount[subTuple]['count'] = count
                    tupleCount[subTuple]['nodes'].append(node)
                else:
                    tupleCount[subTuple] = {
                    'count': 1,
                    'nodes': [node]
                    }
                    count=1
                maxCount = max(maxCount, count)
            print("tupleCount: ", tupleCount)

            subTuples = tupleCount.keys()
            print("subTuples of keys: ", subTuples)

            newColourNodes = []
            multipleSameMaxCount = False
            for i in subTuples:
                if tupleCount[i]['count'] != maxCount:
                    curNodes = tupleCount[i]['nodes']
                    for j in curNodes:
                        self.G.nodes[j]['colour'] = newColour+1
                        newColourNodes.append(j)
                    newColour = newColour +1
                    self.L.append(newColour)
                else:
                    if multipleSameMaxCount:
                        curNodes = tupleCount[i]['nodes']
                        for j in curNodes:
                            self.G.nodes[j]['colour'] = newColour+1
                            newColourNodes.append(j)
                        newColour = newColour +1
                        self.L.append(newColour)
                    else:
                        multipleSameMaxCount = True

            if colourChangeCheck == newColourNodes or len(self.L)==0:
                print("no change in colour")
                self.L.clear()
            else:
                print("change in colour")
                colourChangeCheck = newColourNodes
                        
            print("Final print: ", self.M)
            self.M.clear()
            for j in self.G.nodes:
                print("node: ", j, "colour: ", self.G.nodes[j]['colour'])


        return 
    



In [25]:
G = nx.Graph()
G.add_node(1, colour='1')
G.add_node(2, colour='1')
G.add_node(3, colour='1')
G.add_node(4, colour='1')
G.add_node(5, colour='1')

G.add_edge(1, 2)
G.add_edge(2, 3)
G.add_edge(3, 4)
G.add_edge(4, 1)
G.add_edge(1, 3)
G.add_edge(1, 5)
G.add_edge(2, 5)


M = []
workList = [1]



In [26]:
# print(list(G.nodes[1][colour]))
print(list(G.nodes[1]['colour']))  # Output: 'red'
print(list(G.edges))
color_attributes = nx.get_node_attributes(G, 'colour')
print(color_attributes) 
print(workList)

['1']
[(1, 2), (1, 4), (1, 3), (1, 5), (2, 3), (2, 5), (3, 4)]
{1: '1', 2: '1', 3: '1', 4: '1', 5: '1'}
[1]


In [27]:
obj = WL(G, workList, M)
obj.OneWL()

newColour:  [1]
----------------START of WorkList----------------
 L and M:  [1] []
L size before pop:  1
nodes:  [1, 2, 3, 4, 5]
colour attri:  {1: '1', 2: '1', 3: '1', 4: '1', 5: '1'}
After adding to multiset:  [(2, 1), (4, 1), (3, 1), (5, 1), (1, 1), (3, 1), (5, 1), (2, 1), (4, 1), (1, 1), (3, 1), (1, 1), (1, 1), (2, 1)]
------------------First Radix sort started-------------------
M before sorting:  [(2, 1), (4, 1), (3, 1), (5, 1), (1, 1), (3, 1), (5, 1), (2, 1), (4, 1), (1, 1), (3, 1), (1, 1), (1, 1), (2, 1)]
Stuffed M:  [(2, 1), (4, 1), (3, 1), (5, 1), (1, 1), (3, 1), (5, 1), (2, 1), (4, 1), (1, 1), (3, 1), (1, 1), (1, 1), (2, 1)]
Idx:  M[j][i]:  1 i:  1 j:  0
Idx:  M[j][i]:  1 i:  1 j:  1
Idx:  M[j][i]:  1 i:  1 j:  2
Idx:  M[j][i]:  1 i:  1 j:  3
Idx:  M[j][i]:  1 i:  1 j:  4
Idx:  M[j][i]:  1 i:  1 j:  5
Idx:  M[j][i]:  1 i:  1 j:  6
Idx:  M[j][i]:  1 i:  1 j:  7
Idx:  M[j][i]:  1 i:  1 j:  8
Idx:  M[j][i]:  1 i:  1 j:  9
Idx:  M[j][i]:  1 i:  1 j:  10
Idx:  M[j][i]:  1 i:  1 