## 1-dimensional Weisfeiler-Lehman algorithm implementation


### Structure:

Input - graph G

Output - stable colouring of the graph nodes

#### Variables:

-   G: graph
-   C[v]: list of node colours; every node initially has the same colour, 1.
-   M: multiset, initially empty
-   L: work list of the colour classes introduced in the latest iteration (initially it holds only the initial colour).
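
As a minimal sketch of this initial state, assuming a plain adjacency dictionary in place of the networkx graph used in the cells below (the names `adjacency`, `colour`, `multiset` and `work_list` are illustrative and do not appear in the notebook code):

```python
from collections import deque

# Hypothetical stand-in for G: Graph 1 from this notebook as an adjacency dict
adjacency = {1: [2, 3], 2: [1, 3], 3: [1, 2, 4], 4: [3]}

# C[v]: every node starts in the same colour class, 1
colour = {v: 1 for v in adjacency}

# M: the multiset of (neighbour, colour) pairs, empty at the start
multiset = []

# L: the work list, seeded with the single initial colour class
work_list = deque([1])
```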

#### Functions:

-   FindNeighboursWithColourC:
    -   Finds the neighbourhood of the nodes with colour X
    -   Adds a pair (neighbour, X) to the multiset M for every neighbour of every node with colour X

-   TupleTransformation:
    -   Merges the pairs (v, c1), ..., (v, cr) of each node v into a single tuple (C(v), c1, ..., cr, v)

-   StuffZeros:
    -   Pads the tuples of the multiset with trailing zeros so that they all have the same length for the radix sort

-   PreprocessTuples:
    -   Removes the stuffed zeros after sorting

-   RadixSort:
    -   Sorts the multiset with a radix sort

-   UpdateColour:
    -   Assigns new colours to the colour classes that were split and adds them to the work list

-   OneWL:
    -   Main function that drives the whole algorithm
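
How padding, sorting and un-padding fit together can be sketched in a few standalone functions. This is only an illustration under stated assumptions: `pad_with_zeros`, `radix_sort` and `strip_zeros` are hypothetical names, not the methods of the `WL` class below, and the sketch assumes that all values are non-negative integers and that 0 is never a node id or a colour.

```python
def pad_with_zeros(tuples):
    """Right-pad every tuple with zeros to the length of the longest one."""
    width = max(len(t) for t in tuples)
    return [t + (0,) * (width - len(t)) for t in tuples]


def radix_sort(tuples):
    """Sort equal-length tuples of non-negative ints lexicographically (LSD radix sort)."""
    if not tuples:
        return []
    width = len(tuples[0])
    n_bins = max(max(t) for t in tuples) + 1  # one bin per possible value
    result = list(tuples)
    for position in range(width - 1, -1, -1):  # last position first
        bins = [[] for _ in range(n_bins)]
        for t in result:
            bins[t[position]].append(t)  # appending keeps each pass stable
        result = [t for b in bins for t in b]
    return result


def strip_zeros(tuples):
    """Drop the padding again; valid only because 0 is never a node id or colour."""
    return [tuple(x for x in t if x != 0) for t in tuples]


# The transformed multiset of Graph 1, taken from the output further down
m = [(1, 1, 1, 1), (1, 1, 1, 2), (1, 1, 1, 1, 3), (1, 1, 4)]
sorted_m = radix_sort(pad_with_zeros(m))
print(sorted_m)
print(strip_zeros(sorted_m))
```

Sizing the bins by the largest value, rather than using a fixed number of bins, keeps the sketch usable once node ids or colours exceed a single digit.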


#### Steps of the Algorithm:

-   Initialise the graph G
-   Initialise the colouring list C[v] to 1 for each node v
-   Initialise the multiset M to empty
-   Initialise the work list L with the initial colour
-   While L is not empty:
    -   Pop the first element c from L
    -   Find all the vertices of colour class c
    -   For each vertex w in the colour class, find the neighbourhood of w, if not already visited
    -   For each vertex v in the neighbourhood of w, add the pair (v, c), i.e. the neighbour and the current colour class, to the multiset M
    -   If M is not empty, sort the multiset M using radix sort
    -   For each node v in the sorted multiset M, replace its tuples (v, c1), ..., (v, cr) with a single tuple (C(v), c1, ..., cr, v)
    -   Apply radix sort on M again
    -   For each tuple (C(v), c1, ..., cr, v) in the sorted multiset, update the colours of the nodes: when a colour class splits, the largest part keeps the previous colour c and every other part gets a new colour
    -   Add the new colour classes to the work list L
    -   Reset the multiset M to empty
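
To cross-check the stable colouring produced by the steps above, a naive colour refinement can be sketched as follows. Note that this is not the work-list procedure described in the steps: it recolours every node in every round from the pair (own colour, sorted neighbour colours), which is simpler but slower. The function name and the adjacency-dictionary input are assumptions made for this illustration.

```python
def naive_colour_refinement(adjacency):
    """Return a stable colouring {node: colour} of an undirected graph."""
    colour = {v: 1 for v in adjacency}
    while True:
        # Signature of a node: its own colour plus the sorted colours of its neighbours
        signature = {
            v: (colour[v], tuple(sorted(colour[w] for w in adjacency[v])))
            for v in adjacency
        }
        # Number the distinct signatures 1, 2, ... in sorted order
        palette = {s: i for i, s in enumerate(sorted(set(signature.values())), start=1)}
        new_colour = {v: palette[signature[v]] for v in adjacency}
        # Classes can only split, so an unchanged class count means the colouring is stable
        if len(set(new_colour.values())) == len(set(colour.values())):
            return new_colour
        colour = new_colour


# Graph 1 from this notebook as an adjacency dict
graph_1 = {1: [2, 3], 2: [1, 3], 3: [1, 2, 4], 4: [3]}
stable = naive_colour_refinement(graph_1)
print(stable)
```

For Graph 1 the resulting partition is {1, 2}, {3}, {4}. The colour numbers themselves may differ from those assigned by the work-list version, so only the partitions should be compared.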


In [1]:
%pip install networkx

Note: you may need to restart the kernel to use updated packages.


In [2]:
import networkx as nx

In [40]:
class WL:
    def __init__(self, G, L, M):
        self.G = G
        self.M = M
        self.L = L
            

    def FindNeighboursWithColourC(self, colour):
        color_attributes = nx.get_node_attributes(self.G, 'colour') #return a dict with key as nodes and value as their colour
        # print("color_attributes: ", color_attributes)

        for i in self.G.nodes: #loop 3
            if int(color_attributes[i]) == colour:
                # record every neighbour of a node with this colour as (neighbour, colour)
                neighbourList = list(self.G.neighbors(i)) #in place of loop 4
                for j in neighbourList:
                    tempTuple = (j, colour)
                    self.M.append(tempTuple)


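    def TupleTransformation(self):
        # Nothing to merge when no neighbours were collected (e.g. a graph without edges);
        # without this guard self.M[0] raises an IndexError
        if not self.M:
            return

        keys = self.M.copy()
        curNode = self.M[0][0]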

        colourList = ()
        nodeColour = -1

        j=0
        self.M.clear()
        for i in range(len(keys)):
            nodeColour = int(self.G.nodes[curNode]['colour'])
            colourList = colourList + (nodeColour, )
            while j<len(keys):
                if(keys[j][0] == curNode):
                    node = (keys[j][1], )  
                    colourList = colourList + node
                    # print("keys[j]: ", keys[j])
                else:
                    colourList = colourList + (curNode, )
                    curNode = keys[j][0]
                    # print("curnode: ", curNode, "node[0]", keys[j][0])
                    break
                j+=1
            if(j == len(keys)):
                    colourList = colourList + (curNode, )
                    
            #now here insert the newly created tuple to the multiset 
            self.M.append(colourList)
            colourList = tuple()
            if(j==len(keys)):
                break


    def StuffZeros(self, M):
        maxLen=-1
        for i in M:
            maxLen = max(maxLen, len(i))

        M_padded = []
        for i in M:
            if len(i) < maxLen:
                padded_list = list(i) + [0] * (maxLen - len(i))
                M_padded.append(tuple(padded_list))
            else:
                M_padded.append(i)
        
        self.M = M_padded.copy()
        # print(self.M)


    #To remove stuffed zeros
    def PreprocessTuples(self):
        newList = []
        for j in range(len(self.M)):
            tempTuple = ()
            for i in self.M[j]:
                if i != 0:
                    tempTuple = tempTuple + (i, )
            newList.append(tempTuple)
        self.M = newList


    def RadixSort(self, M):
        #number of tuples in the list keys
        n = len(self.M)
        if n == 0:
            return []
        
        #pad the tuples with zeros so that they all have the same length
        self.StuffZeros(M)
        print("Stuffed M: ", self.M)

        numberOfVal = len(self.M[0]) #values in each tuple in the multiset

        #one bin per possible value; sizing the bins by the largest value avoids
        #an IndexError as soon as a node id or colour reaches 10
        numberOfBins = max(max(t) for t in self.M) + 1
        bins = [[] for _ in range(numberOfBins)]

        #LSD radix sort: one stable bucket pass per position, from the last position to the first
        for i in range(numberOfVal-1, -1, -1):
            for j in range(n):
                idx = self.M[j][i]
                bins[idx].append(self.M[j])

            self.M.clear()
            for j in range(numberOfBins):
                for itr in bins[j]:
                    self.M.append(itr)
                bins[j].clear()

        return self.M


    def UpdateColour(self, previousColour, newColour):
        #group the nodes by embedding: every tuple in M is (C(v), c1, ..., cr, v)
        embeddings = {}
        for i in self.M:
            node = i[-1]
            embedding = i[:-1]
            if embedding in embeddings:
                embeddings[embedding]['nodes'].append(node)
            else:
                embeddings[embedding] = {'nodes': [node]}
        print("embeddings: ", embeddings)

        embeddingKeys = embeddings.keys()
        print("embedding keys : ", embeddingKeys)
        print("previous colour : ", previousColour)

        #identify the nodes whose embeddings have changed since the previous iteration
        modifiedTuples = {}
        for key in embeddingKeys:
            if previousColour[str(key[0])] != embeddings[key]['nodes']:
                if key[0] not in modifiedTuples:
                    modifiedTuples[key[0]] = [embeddings[key]['nodes']]
                else:
                    modifiedTuples[key[0]].append(embeddings[key]['nodes'])
        print("modifiedTuples: ", modifiedTuples)

        #identify the length of the longest node list for each colour
        maxLenTupleInAColour = {}
        for k, v in modifiedTuples.items():
            maxVal = -1
            for l in v:
                maxVal = max(maxVal, len(l))
            maxLenTupleInAColour[k] = maxVal
        # print("maxLenTupleInAColour: ", maxLenTupleInAColour)
        
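        #remove exactly one of the longest node lists, so that this part keeps the previous colour
        for k in maxLenTupleInAColour.keys():
            for v in modifiedTuples[k]:
                if(len(v) == maxLenTupleInAColour[k]):
                    modifiedTuples[k].remove(v)
                    break #stop here: the list is being modified, and only one part may keep the old colour
        print("modifiedTuple after removal: ", modifiedTuples)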

        #Update the node colour
        for k, v in modifiedTuples.items():
            for i in v:
                for j in i:
                    # print("j: ", j, "node: ", self.G.nodes[j]['colour'])
                    self.G.nodes[j]['colour'] = str(newColour+1)
                    # print("j: ", j, "node: ", self.G.nodes[j]['colour'])

                newColour = newColour+1
                self.L.append(newColour)
        self.M.clear()



    def OneWL(self):
        #iterates on the colour classes
        newColour = self.L[-1]

        previousColour = {} #set previous colours of nodes
        color_attributes = nx.get_node_attributes(self.G, 'colour').items()
        for k, v in color_attributes:
            if v not in previousColour:
                previousColour[v] = [k]
            else:
                previousColour[v].append(k)
        print("prev_colour: ", previousColour)


        print("newColour: ", self.L)
        while self.L: #loop 1
            print("----------------START of WorkList----------------")
            print(" L and M: ", self.L, self.M)

            # test=0
            LSize = len(self.L)
            for i in range(LSize): #loop 2
                #get the colour class at the front
                print("L size before pop: ", len(self.L))
                colour = self.L.pop(0)
                #loop to get all the nodes from the graph with colour c
                print("nodes: ", self.G.nodes)
                self.FindNeighboursWithColourC(colour) #adding neighbours to the multiset


            #first radix sort
            print("------------------First Radix sort started-------------------")
            print("M before sorting: ", self.M)
            self.RadixSort(self.M)        
            print("M after sorting: ", self.M)
            print("------------------First Radix sort completed-------------------")


            #Function call to transform the multi-tuple multiset into a single-tuple multiset
            self.TupleTransformation()
            print("M after transformation: ", self.M)


            #second radix sort
            print("------------------Second Radix sort started-------------------")
            print("M before sorting: ", self.M)
            self.RadixSort(self.M)
            print("M after sorting: ", self.M)
            print("------------------Second Radix sort completed-------------------")

            #Remove the zeroStuffing of tuples
            print("before processing: ", self.M)
            self.PreprocessTuples()
            print("post-processing: ", self.M)
            print("-----------------------------------------------------------")


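            #generating new tuple embedding(colour)
            #UpdateColour does not report how many colours it handed out, so take the
            #largest colour currently in use as the base for the new ones
            newColour = max(int(c) for c in nx.get_node_attributes(self.G, 'colour').values())
            self.UpdateColour(previousColour, newColour)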
            print("------------------Colour updation function executed-------------------")

            # update the previous colouring information
            color_attributes = nx.get_node_attributes(self.G, 'colour') #return a dict with key as nodes and value as their colour

            previousColour = {}
            for k, v in color_attributes.items():
                if v not in previousColour:
                    previousColour[v] = [k]
                else:
                    previousColour[v].append(k)
            print("Updated colours: ", color_attributes)


        return 
    



In [41]:
# # Graph 0
# G = nx.Graph()
# G.add_node(1, colour='1')
# G.add_node(2, colour='1')
# G.add_node(3, colour='1')
# G.add_node(4, colour='1')
# G.add_node(5, colour='1')


# G.add_edge(1, 2)
# G.add_edge(2, 3)
# G.add_edge(3, 4)
# G.add_edge(4, 1)
# G.add_edge(1, 3)
# G.add_edge(1, 5)
# G.add_edge(2, 5)


# M = []
# workList = [1]



In [42]:
# Graph 1
G = nx.Graph()
G.add_node(1, colour='1')
G.add_node(2, colour='1')
G.add_node(3, colour='1')
G.add_node(4, colour='1')


G.add_edge(1, 2)
G.add_edge(1, 3)
G.add_edge(2, 3)
G.add_edge(4, 3)


M = []
workList = [1]



In [43]:
obj = WL(G, workList, M)
obj.OneWL()

prev_colour:  {'1': [1, 2, 3, 4]}
newColour:  [1]
----------------START of WorkList----------------
 L and M:  [1] []
L size before pop:  1
nodes:  [1, 2, 3, 4]
------------------First Radix sort started-------------------
M before sorting:  [(2, 1), (3, 1), (1, 1), (3, 1), (1, 1), (2, 1), (4, 1), (3, 1)]
Stuffed M:  [(2, 1), (3, 1), (1, 1), (3, 1), (1, 1), (2, 1), (4, 1), (3, 1)]
M after sorting:  [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1), (3, 1), (3, 1), (4, 1)]
------------------First Radix sort completed-------------------
M after transformation:  [(1, 1, 1, 1), (1, 1, 1, 2), (1, 1, 1, 1, 3), (1, 1, 4)]
------------------Second Radix sort started-------------------
M before sorting:  [(1, 1, 1, 1), (1, 1, 1, 2), (1, 1, 1, 1, 3), (1, 1, 4)]
Stuffed M:  [(1, 1, 1, 1, 0), (1, 1, 1, 2, 0), (1, 1, 1, 1, 3), (1, 1, 4, 0, 0)]
M after sorting:  [(1, 1, 1, 1, 0), (1, 1, 1, 1, 3), (1, 1, 1, 2, 0), (1, 1, 4, 0, 0)]
------------------Second Radix sort completed-------------------
before processi

In [None]:
# Graph 2
# G = nx.Graph()
# G.add_node(1, colour='1')
# G.add_node(2, colour='1')
# G.add_node(3, colour='1')
# G.add_node(4, colour='1')


# G.add_edge(1, 2)
# G.add_edge(2, 3)
# G.add_edge(3, 4)
# G.add_edge(4, 1)
# G.add_edge(1, 3)


# M = []
# workList = [1]



In [None]:
# # Graph 3
# G = nx.Graph()
# G.add_node(1, colour='1')
# G.add_node(2, colour='1')
# G.add_node(3, colour='1')
# G.add_node(4, colour='1')
# G.add_node(5, colour='1')
# G.add_node(6, colour='1')
# G.add_node(7, colour='1')
# G.add_node(8, colour='1')


# G.add_edge(1, 2)
# G.add_edge(1, 5)
# G.add_edge(1, 3)
# G.add_edge(2, 4)
# G.add_edge(2, 6)
# G.add_edge(3, 4)
# G.add_edge(3, 7)
# G.add_edge(4, 8)
# G.add_edge(5, 6)
# G.add_edge(7, 8)


# M = []
# workList = [1]



In [None]:
# print(list(G.nodes[1][colour]))
print(list(G.nodes[1]['colour']))  # list of the characters of the colour string, e.g. ['2']
print(list(G.edges))
color_attributes = nx.get_node_attributes(G, 'colour')
print(color_attributes) 
print(workList)

['2']
[(1, 2), (1, 4), (1, 3), (1, 5), (2, 3), (2, 5), (3, 4)]
{1: '2', 2: '1', 3: '1', 4: '3', 5: '3'}
[]


In [None]:
#Alternative way to make the graph
G = nx.Graph()
G.add_nodes_from([(1, {'colour': '1'}), (2, {'colour': '1'}), (3, {'colour': '1'}), (4, {'colour': '1'}), (5, {'colour': '1'})])
G.add_edges_from([(1, 2), (2, 3), (3, 4), (4, 1), (1, 3), (1, 5), (2, 5)])


M = []
workList = [1]
