In [None]:
import torch 
from dataset_generator import pathfinding

import torch 
from torch_geometric.utils import to_networkx

import networkx as nx
import matplotlib.pyplot as plt
from graphviz import Digraph


In [None]:
path = "models/model_checkpoints/PathFinding_5_mem_False.pt"
# The checkpoint holds a whole pickled model object (later cells call methods such as
# set_iterations / set_hardmax on it), not just a state_dict. Recent torch versions
# default torch.load to weights_only=True, which refuses such objects, so opt out
# explicitly. map_location="cpu" because every input built in this notebook
# (map_x_to_state, dataset tensors) lives on the CPU.
# NOTE: weights_only=False unpickles arbitrary objects - only load trusted checkpoints.
model = torch.load(path, map_location="cpu", weights_only=False)
# Width of the one-hot state encoding built by map_x_to_state; matches the five
# state names ('f0', 'f1', 's0', 's1', 's2') used by the visualisation cells.
STATES_N = 5

def map_x_to_state(inputs, states_n):
    """Encode each node's input features as a one-hot initial state.

    For every row ``features`` of ``inputs`` (in order) a one-hot row of length
    ``states_n`` is built:
      * index 2 if ``features[0] == 1``;
      * index 3 for the first row (not already matched above) with ``features[1] == 1``;
      * index 4 for every other row.
    Indices 2..4 are the ones the visualisation cells label 's0', 's1', 's2'.
    Assumes ``states_n >= 5``.

    Returns a float tensor of shape ``(len(inputs), states_n)``.
    """
    rows = []
    seen_first_marker = False
    for features in inputs:
        if features[0] == 1:
            slot = 2
        elif features[1] == 1 and not seen_first_marker:
            slot = 3
            seen_first_marker = True
        else:
            slot = 4

        one_hot = torch.zeros(states_n)
        one_hot[slot] = 1
        rows.append(one_hot.unsqueeze(0))
    return torch.cat(rows, dim=0)

# define color map from node features (original note: root node = red, other nodes = blue)
def map_tensor_to_color(tensor):
    """Return a matplotlib colour for one node's feature vector.

    Feature 0 set to 1 -> "#2980b9" (blue); otherwise feature 1 set to 1 -> "red";
    anything else -> "grey". Feature 0 is checked first.
    """
    for feature_idx, colour in ((0, "#2980b9"), (1, "red")):
        if tensor[feature_idx] == 1:
            return colour
    return "grey"

# TODO: this is just a hack to retrieve multiple state machines -> define which state machine to use
def map_x_to_start_ids(inputs):
    """Assign a start-machine id to each node, mirroring map_x_to_state's rules.

    Rows with ``features[0] == 1`` get id 0; the first remaining row with
    ``features[1] == 1`` gets id 1; every other row gets id 2.

    Returns a plain Python list of ints, one per row of ``inputs``.
    """
    ids = []
    marker_seen = False
    for row in inputs:
        if row[0] == 1:
            start_id = 0
        elif row[1] == 1 and not marker_seen:
            start_id, marker_seen = 1, True
        else:
            start_id = 2
        ids.append(start_id)
    return ids


In [None]:
train_dataset = pathfinding.PathFinding(num_nodes=8).data
sample_graph = train_dataset[0]
G = to_networkx(sample_graph)

# Input view: node colours derived from the node features via map_tensor_to_color.
print("input picture")
input_colors = [map_tensor_to_color(node_features) for node_features in sample_graph.x]
nx.draw_kamada_kawai(G, node_color=input_colors)
plt.show()

# Ground-truth view: label 0 -> blue, any other label -> red.
print("ground truth")
color_map = ["#2980b9" if label == 0 else "red" for label in sample_graph.y]
nx.draw_kamada_kawai(G, node_color=color_map)
plt.show()


In [None]:
sample_graph = train_dataset[0]
# NOTE(review): computed but not passed to the model below (this call uses no start
# states); kept so the name stays defined for anyone relying on it.
start_states = map_x_to_start_ids(sample_graph.x)

# Later cells call model.set_logging(True), after which model(...) returns a tuple
# (out, all_states, transitions). Reset it so this cell also works on re-run.
# NOTE(review): set_hardmax(True) from later cells is NOT reset here - confirm the default.
model.set_logging(False)
model.set_iterations(15)
logits = model(map_x_to_state(sample_graph.x, STATES_N), sample_graph.edge_index)
out = torch.argmax(logits, dim=-1)

print("prediction")
color_map = ['#34495e' if pred == 0 else '#e74c3c' for pred in out]
nx.draw_kamada_kawai(G, node_color=color_map)
plt.show()

# Green = prediction matches the ground-truth label, red = mismatch.
print("correctness")
color_map = ['#2ecc71' if out[i] == sample_graph.y[i] else '#c0392b' for i in range(len(out))]
nx.draw_kamada_kawai(G, node_color=color_map)
plt.show()

### Visualize the state machine of each node at each iteration

In [None]:
from graphviz import Digraph

NUM_NODES = 7

model.set_iterations(9)

train_dataset = pathfinding.PathFinding(num_nodes=NUM_NODES, num_graphs=10).data
dataset_index = 4

# Reverse FIRST, then select the graph. Previously start_states was derived from
# train_dataset[dataset_index] *before* the reversal, i.e. from a different graph than
# the one that was plotted and fed to the model below.
train_dataset.reverse()
graph = train_dataset[dataset_index]
start_states = map_x_to_start_ids(graph.x)

G = to_networkx(graph)

color_map = [map_tensor_to_color(tensor) for tensor in graph.x]
nx.draw_kamada_kawai(G, node_color=color_map, with_labels=True)
plt.show()

# set model to argmax and enable logging -> model(...) returns (out, all_states, transitions)
model.set_hardmax(True)
model.set_logging(True)

out, all_states, transitions = model(map_x_to_state(graph.x, STATES_N), graph.edge_index, start_states)
out = torch.argmax(out, dim=-1)

# Green = correct prediction, red = wrong prediction.
color_map = ['#2ecc71' if out[i] == graph.y[i] else '#c0392b' for i in range(len(out))]
nx.draw_kamada_kawai(G, node_color=color_map)
plt.show()

state_names = ['f0', 'f1', 's0', 's1', 's2']

# draw state machine for each node; all_states is indexed [iteration][node]
for i in range(len(graph.x)):
    fsm = Digraph(graph_attr={'rankdir': 'LR'}, node_attr={'shape': 'circle', "width": "0.8"})

    # .item() so equal states collapse in the set (0-d tensors hash by identity)
    states_used = {torch.argmax(state[i]).item() for state in all_states}
    for index in sorted(states_used):
        fsm.node(state_names[index], penwidth='3px')

    prev_state = None
    for step, state in enumerate(all_states):
        state_index = torch.argmax(state[i]).item()
        if prev_state is not None:
            # label = "<iteration> - <transition input that led to this state>"
            transition = transitions[step - 1][i].tolist()
            fsm.edge(state_names[prev_state], state_names[state_index], f"{step} - {transition}")
        prev_state = state_index

    # one file per node instead of overwriting the same file on every iteration
    fsm.render(f'task0_ref_fsm_node{i}', format='svg')
    display(fsm)




In [None]:

state_names = ['f0', 'f1', 's0', 's1', 's2']

G = Digraph(graph_attr={'rankdir': 'LR'}, node_attr={'shape': 'circle', "width": "0.8"})
for name in state_names:
    G.node(name, penwidth='3px')

# distinct (from_state, to_state, transition_label) triples already drawn
added = set()
num_graphs = 5

for n in range(3, 20):

    model.set_iterations(n + 2)
    # set model to argmax and enable logging -> model(...) returns (out, all_states, transitions)
    model.set_hardmax(True)
    model.set_logging(True)
    train_dataset = pathfinding.PathFinding(num_nodes=n, num_graphs=num_graphs).data

    for graph_idx in range(num_graphs):
        graph = train_dataset[graph_idx]
        _, all_states, transitions = model(map_x_to_state(graph.x, STATES_N), graph.edge_index)

        for i in range(len(graph.x)):
            current_state = torch.argmax(all_states[0][i], dim=-1).item()

            # all_states is indexed [iteration][node]: walk every recorded iteration
            # (the old bound len(all_states[i]) was the node count, not the iteration count).
            for step in range(1, len(all_states)):
                next_state = torch.argmax(all_states[step][i], dim=-1).item()
                # transitions[step-1][i] labels node i's move into `step`; the old
                # transitions[j][i-1] read a neighbouring node (and node -1 for i == 0).
                transition_value = str(transitions[step - 1][i].tolist())

                transition_tuple = (state_names[current_state], state_names[next_state], transition_value)
                if transition_tuple not in added:
                    G.edge(*transition_tuple)
                    added.add(transition_tuple)

                current_state = next_state

G.render('task0_ref_fsm', format='svg')

display(G)



### Visualization of the entire state machine


In [None]:
from graphviz import Digraph
from itertools import product
import math

state_names = ['f0', 'f1', 's0', 's1', 's2']  # , "r1"]
num_states = len(state_names)

# apply hardmax to T: one-hot over the last dimension
T = torch.nn.functional.one_hot(
        torch.argmax(model.T, dim=-1), model.T.shape[-1]
    ).to(torch.float32)

G = Digraph(graph_attr={'rankdir': 'LR'}, node_attr={'shape': 'circle', "width": "0.8"})
for name in state_names:
    G.node(name, penwidth='3px')


def bin_count(value):
    """Return the number of set bits in the non-negative int ``value``."""
    return bin(value).count("1")


for x, y in product(range(num_states), range(num_states)):
    # condition indices c (rendered as num_states-bit masks) for which T moves x -> y
    edges = [c for c in range(2 ** num_states) if T[c][x][y] == 1]

    edge_strings = []
    for e1 in edges:
        edge_string = list("[" + format(e1, f"0{num_states}b") + "]")

        # TODO: we can merge the edges even more
        # If another condition differs from e1 in exactly one bit, render that bit as "*",
        # e.g. [1,0,0] and [1,0,1] become the single edge [1,0,*].
        for e2 in edges:
            diff = e1 ^ e2
            if bin_count(diff) == 1:
                position = diff.bit_length() - 1  # bit index, 0 = least significant
                # string is "[b_{k-1} ... b_1 b_0]": bit p sits at index -(p + 2)
                edge_string[-position - 2] = "*"
                break

        # Always keep the condition; previously conditions without a one-bit
        # neighbour were silently dropped from the diagram.
        edge_strings.append("".join(edge_string))

    for edge in set(edge_strings):
        G.edge(state_names[x], state_names[y], edge)


G.render('task0_ref_fsm', format='svg')

display(G)


In [None]:
state_names = ['f0', 'f1', 's0', 's1', 's2']
num_graphs = 10

# One diagram per starting state s0, s1, s2 (indices 2..4 of state_names), built from
# every distinct transition seen for nodes whose initial state is `s`.
for s in range(2, 5):
    G = Digraph(graph_attr={'rankdir': 'LR'}, node_attr={'shape': 'circle', "width": "0.8"})
    for name in state_names:
        G.node(name, penwidth='3px')

    # distinct (from_state, to_state, transition_label) triples already drawn
    added = set()

    for n in range(3, 20):

        model.set_iterations(n + 2)
        # set model to argmax and enable logging -> model(...) returns (out, all_states, transitions)
        model.set_hardmax(True)
        model.set_logging(True)
        train_dataset = pathfinding.PathFinding(num_nodes=n, num_graphs=num_graphs).data

        for graph_idx in range(num_graphs):
            graph = train_dataset[graph_idx]
            _, all_states, transitions = model(map_x_to_state(graph.x, STATES_N), graph.edge_index)

            for i in range(len(graph.x)):
                current_state = torch.argmax(all_states[0][i], dim=-1).item()
                if current_state != s:
                    continue

                # all_states is indexed [iteration][node]: walk every recorded iteration
                for step in range(1, len(all_states)):
                    next_state = torch.argmax(all_states[step][i], dim=-1).item()
                    # transitions[step-1][i]: node i's transition label into `step`
                    transition_value = str(transitions[step - 1][i].tolist())

                    transition_tuple = (state_names[current_state], state_names[next_state], transition_value)
                    if transition_tuple not in added:
                        # debug: report transitions leaving the f-states towards an s-state
                        if state_names[current_state] in ['f0', 'f1'] and state_names[next_state] not in ['f0', 'f1']:
                            print("for n", n, "i", i, "graph", graph_idx, "step", step, "state", state_names[current_state], state_names[next_state])

                        G.edge(*transition_tuple)
                        added.add(transition_tuple)

                    current_state = next_state

    print("starting state s", s - 2)
    # one file per starting state instead of overwriting the same file
    G.render(f'task0_ref_fsm_s{s - 2}', format='svg')
    display(G)



In [None]:
# TODO: this uses some not-so-nice code -> we should introduce a separate training pipeline!
# code to visualize multi-state machine model

state_names = ['f0', 'f1', 's0', 's1', 's2']

# NOTE(review): map_x_to_start_ids only yields ids 0, 1 and 2, so the diagram for
# start_id 3 never gets edges; the range is kept in case the model has a 4th machine.
for start_id in range(0, 4):
    G = Digraph(graph_attr={'rankdir': 'LR'}, node_attr={'shape': 'circle', "width": "0.8"})
    for name in state_names:
        G.node(name, penwidth='3px')

    # distinct (from_state, to_state, transition_label) triples already drawn
    added = set()

    for n in range(3, 10):

        model.set_iterations(n + 2)
        # set model to argmax and enable logging -> model(...) returns (out, all_states, transitions)
        model.set_hardmax(True)
        model.set_logging(True)
        train_dataset = pathfinding.PathFinding(num_nodes=n).data

        for graph in train_dataset:

            start_state_ids = map_x_to_start_ids(graph.x)
            _, all_states, transitions = model(map_x_to_state(graph.x, STATES_N), graph.edge_index, start_state_ids)

            for i in range(len(graph.x)):
                # only nodes assigned to the state machine currently being drawn
                if start_id != start_state_ids[i]:
                    continue

                current_state = torch.argmax(all_states[0][i], dim=-1).item()

                # all_states is indexed [iteration][node]: walk every recorded iteration
                # (the old bound len(all_states[i]) was the node count, not the iteration count).
                for step in range(1, len(all_states)):
                    next_state = torch.argmax(all_states[step][i], dim=-1).item()
                    transition_value = str(transitions[step - 1][i].tolist())

                    transition_tuple = (state_names[current_state], state_names[next_state], transition_value)
                    if transition_tuple not in added:
                        G.edge(*transition_tuple)
                        added.add(transition_tuple)

                    current_state = next_state

    print("starting state s", start_id)
    # one file per start machine instead of overwriting the same file
    G.render(f'task0_ref_fsm_start{start_id}', format='svg')
    display(G)

