In [None]:
import torch 
from dataset_generator import distance

import torch 
from torch_geometric.utils import to_networkx

import networkx as nx
import matplotlib.pyplot as plt

from graphviz import Digraph
from itertools import product
import math


In [None]:
# Width of the per-node state vector built by map_x_to_state.
STATES_N = 4

# Checkpoint to visualise.
path = "runs/Distance/Distance_4_mem_False_1697132227.167031.pt"

# The cells below call methods on the loaded object (set_logging, set_iterations, ...), so the
# file has to contain a fully pickled model, not just a state_dict.
# `weights_only=False` is spelled out because PyTorch >= 2.6 defaults to `weights_only=True`,
# which refuses to unpickle arbitrary model classes.
# SECURITY: full unpickling can execute arbitrary code -- only load checkpoints you trust.
model = torch.load(path, weights_only=False)

def map_x_to_state(inputs, states_n):
    """Embed per-node input features into state vectors.

    For every row ``x`` of ``inputs`` a zero vector of length ``states_n`` is
    created, ``x[0]`` is written to index 2 and ``x[1]`` to index 3
    (the "Distance mapping").

    Args:
        inputs: Iterable of rows (tensor rows or sequences) with at least two entries each.
        states_n: Length of each state vector; must be at least 4.

    Returns:
        A float tensor of shape ``(len(inputs), states_n)``. Empty ``inputs`` yields a
        ``(0, states_n)`` tensor instead of raising.

    Raises:
        IndexError: If ``states_n`` is smaller than 4 or a row has fewer than two entries.
    """
    rows = []

    for x in inputs:
        row = torch.zeros(states_n)
        # Distance mapping
        row[2] = x[0]
        row[3] = x[1]
        rows.append(row)

    # torch.stack / torch.cat reject an empty list, so handle "no nodes" explicitly.
    if not rows:
        return torch.zeros((0, states_n))

    return torch.stack(rows, dim=0)

# ToDo: this is just a hack to retrieve multiple state machines -> define which state machine to use
def map_x_to_start_ids(inputs):
    """Return one start id per feature row.

    A row whose first entry equals 1 is assigned id 0; every other row is assigned id 1.
    """
    return [0 if features[0] == 1 else 1 for features in inputs]


In [None]:
# Generate a single 20-node graph and convert it for plotting.
train_dataset = distance.Distance(num_graphs=1, num_nodes=20).data
sample = train_dataset[0]
G = to_networkx(sample)

# Input colouring: light grey where the first feature is 0, blue otherwise.
color_map = ['#2980b9' if tensor[0] != 0 else '#bdc3c7' for tensor in sample.x]

print("input picture")
nx.draw_kamada_kawai(G, node_color=color_map)
plt.show()


# Ground-truth colouring: dark where the label is 0, red otherwise.
print("ground truth")
color_map = ['#e74c3c' if value != 0 else '#34495e' for value in sample.y]
nx.draw_kamada_kawai(G, node_color=color_map)
plt.show()

# Run the model for 10 iterations with logging switched off; the call result is used
# directly as the output tensor and reduced to a class index per node.
model.set_logging(False)
model.set_iterations(10)
out = model(map_x_to_state(sample.x, STATES_N), sample.edge_index)
out = torch.argmax(out, dim=-1)

# Prediction colouring uses the same palette as the ground truth.
print("prediction")
color_map = ['#e74c3c' if predicted != 0 else '#34495e' for predicted in out]
nx.draw_kamada_kawai(G, node_color=color_map)
plt.show()

# Green where prediction and label agree, dark red where they differ.
print("difference")
color_map = ['#2ecc71' if predicted == sample.y[i] else '#c0392b' for i, predicted in enumerate(out)]
nx.draw_kamada_kawai(G, node_color=color_map)
plt.show()


In [None]:
from graphviz import Digraph
from itertools import product
import math

# Apply hardmax to T: one-hot encode the argmax over the last dimension, so that for every
# (c, x) exactly one target y has T[c][x][y] == 1.
T = torch.nn.functional.one_hot(
        torch.argmax(model.T, dim=-1), model.T.shape[-1]
    ).to(torch.float32)

G = Digraph(graph_attr={'rankdir':'LR'}, node_attr={'shape':'circle', "width": "0.8"})

state_names = ['f0', 'f1', 's0', 's1']
# Outgoing edges of these states are never drawn.
final_states = ('f0', 'f1')

for name in state_names:
    G.node(name, penwidth='3px')

# T is indexed as T[c][x][y]: c runs over the first axis, x is the source state and y the
# target state. Edges are emitted in (x, y, c) order so the generated graph is deterministic.
# NOTE(review): presumably c encodes which states occur among a node's neighbours -- confirm
# against the model definition.
# NOTE(review): labels are padded to 5 binary digits although only 2**len(state_names) values
# of c are enumerated; the width is kept unchanged from the previous version.
# TODO: edges could be merged further -- two labels that differ in a single bit could be drawn
# as one edge with that bit replaced by "*", e.g. [1,0,0] and [1,0,1] -> [1,0,*].
for x, source in enumerate(state_names):
    if source in final_states:
        continue

    for y, target in enumerate(state_names):
        for c in range(2 ** len(state_names)):
            if T[c][x][y] == 1:
                G.edge(source, target, "[" + format(c, '05b') + "]")


G.render('task0_ref_fsm', format='svg')

display(G)


### Visualize the state machine of each node in each iteration

In [None]:
# IMPORTANT!
# check below for the case where we train multiple state machines

train_dataset = distance.Distance(num_graphs=1, num_nodes=10).data +  distance.Distance(num_graphs=1, num_nodes=1).data
G = to_networkx(train_dataset[0])

# Input colouring: light grey where the first feature is 0, blue otherwise.
color_map = ['#bdc3c7' if tensor[0] == 0 else '#2980b9' for tensor in train_dataset[0].x]

print("input picture")
nx.draw_kamada_kawai(G, node_color=color_map, with_labels=True)
plt.show()


# set model to argmax and enable logging
model.set_hardmax(True)
model.set_logging(True)
model.set_iterations(6)

# Not read again in this cell; kept because other cells may rely on these names.
inputs = map_x_to_state(train_dataset[0].x, 4)[:2]
indexes = train_dataset[0].edge_index[:2]

# With logging enabled the model call is unpacked into the output, the per-iteration states
# and the per-iteration transitions.
out, all_states, transitions = model(map_x_to_state(train_dataset[0].x, STATES_N), train_dataset[0].edge_index)
out = torch.argmax(out, dim=-1)

# Green where prediction and label agree, dark red where they differ.
print("difference")
color_map = ['#2ecc71' if out[i] == train_dataset[0].y[i] else '#c0392b' for i in range(len(out))]
nx.draw_kamada_kawai(G, node_color=color_map)
plt.show()


state_names = ['f0', 'f1', 's0', 's1']

# draw state machine for each node (10 matches num_nodes of the first graph above)
for i in range(10):
    print(f"state machine used {i}")

    G = Digraph(graph_attr={'rankdir':'LR'}, node_attr={'shape':'circle', "width": "0.8"})

    # State index of node i after every logged iteration. `.item()` converts each argmax to a
    # plain int: 0-dim tensors hash by identity, so a set of tensors would never deduplicate.
    visited = [torch.argmax(state[i]).item() for state in all_states]

    for index in set(visited):
        G.node(state_names[index], penwidth='3px')

    # One edge per step, labelled "<step> - <transition>". The counter is called `step` rather
    # than `round` so the builtin is not rebound for the rest of the kernel session.
    for step in range(1, len(visited)):
        transition = transitions[step-1][i].tolist()
        G.edge(state_names[visited[step-1]], state_names[visited[step]], f"{step} - {transition}")

    # NOTE: every node is rendered to the same file, so only the last one remains on disk.
    G.render('task0_ref_fsm', format='svg')
    display(G)




In [None]:


state_names = ['f0', 'f1', 's0', 's1']

# One graph per start state. Start states begin at index 2 ('s0'); the upper bound follows
# state_names because any larger index could never be looked up in it.
for s in range(2, len(state_names)):
    G = Digraph(graph_attr={'rankdir':'LR'}, node_attr={'shape':'circle', "width": "0.9"})

    for name in state_names:
        G.node(name, penwidth='3px')

    # (source, target, label) triples already drawn; a set keeps the membership test cheap.
    added = set()

    for n in range(2, 20):

        model.set_iterations(n+2)
        # Both operands need `.data` so that two lists are concatenated.
        train_dataset = distance.Distance(num_nodes=n, num_graphs=2).data + distance.Distance(num_nodes=1, num_graphs=1).data

        for graph_idx in range(2):

            # set model to argmax and enable logging
            model.set_hardmax(True)
            model.set_logging(True)

            out, all_states, transitions = model(map_x_to_state(train_dataset[graph_idx].x, 4), train_dataset[graph_idx].edge_index)
            out = torch.argmax(out, dim=-1)

            for i in range(n):

                # Only follow nodes whose first logged state is the start state of this graph.
                current_state = torch.argmax(all_states[0][i], dim=-1).item()
                if current_state != s:
                    continue

                # Walk over every logged iteration: all_states is indexed by iteration first,
                # so its own length (not the length of one entry) bounds the walk.
                for step in range(1, len(all_states)):

                    next_state = torch.argmax(all_states[step][i], dim=-1).item()
                    transition_value = str(transitions[step-1][i].tolist())

                    transitionTuple = (state_names[current_state], state_names[next_state], transition_value)

                    if transitionTuple not in added:

                        G.edge(state_names[current_state], state_names[next_state], transition_value)
                        added.add(transitionTuple)

                    current_state = next_state

    print("starting state s", s-2)
    G.render('task0_ref_fsm', format='svg')
    display(G)



In [None]:
# extracting multiple state machines (we trained a state machine for each input class)

state_names = ['f0', 'f1', 's0', 's1']
# Outgoing transitions of these states are not drawn.
final_states = ("f0", "f1")

for start_id in range(0,2):
    G = Digraph(graph_attr={'rankdir':'LR'}, node_attr={'shape':'circle', "width": "0.8"})

    for name in state_names:
        G.node(name, penwidth='3px')

    # (source, target, label) triples already drawn; a set keeps the membership test cheap.
    added = set()

    for n in range(2, 12):

        model.set_iterations(n+2)
        train_dataset = distance.Distance(num_nodes=n).data

        for graph_idx in range(len(train_dataset)):

            start_state_ids = map_x_to_start_ids(train_dataset[graph_idx].x)

            # set model to argmax and enable logging
            model.set_hardmax(True)
            model.set_logging(True)

            out, all_states, transitions = model(map_x_to_state(train_dataset[graph_idx].x, STATES_N), train_dataset[graph_idx].edge_index)
            out = torch.argmax(out, dim=-1)

            for i in range(n):

                # Only follow nodes that belong to the state machine drawn in this pass.
                if start_id != start_state_ids[i]:
                    continue

                current_state = torch.argmax(all_states[0][i], dim=-1).item()

                for step in range(1, len(all_states)):

                    # Once a final state is reached nothing more is drawn for this node. The
                    # earlier `continue` also skipped the state update, so it already behaved
                    # like this `break`.
                    if state_names[current_state] in final_states:
                        break

                    next_state = torch.argmax(all_states[step][i], dim=-1).item()
                    transition_value = str(transitions[step-1][i].tolist())

                    transitionTuple = (state_names[current_state], state_names[next_state], transition_value)

                    if transitionTuple not in added:

                        G.edge(state_names[current_state], state_names[next_state], transition_value)
                        added.add(transitionTuple)

                    current_state = next_state

    print("starting state s", start_id)
    G.render('task0_ref_fsm', format='svg')
    display(G)