In [None]:
import torch 
from dataset_generator import prefixsum

import torch 
from torch_geometric.utils import to_networkx

import networkx as nx
import matplotlib.pyplot as plt
from graphviz import Digraph


In [None]:
# Checkpoint of the trained model that the cells below visualise.
path = "models/model_checkpoints/PrefixSum_3_mem_False.pt"
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code --
# only point `path` at checkpoints from a trusted source.
model = torch.load(path)


# Length of the state vectors passed to map_x_to_state in the multi-state-machine cells.
STATES_N = 3

def map_x_to_state(inputs, states_n):
    """Build the initial one-hot state vector for every node.

    Every node starts in the state with index 2; the feature values themselves
    are not inspected, only the number of entries in ``inputs`` matters.

    Args:
        inputs: Sized collection with one entry per node (e.g. the ``x``
            tensor of a graph, or a list of feature rows).
        states_n: Length of each state vector; must be at least 3.

    Returns:
        Float tensor of shape ``(len(inputs), states_n)`` whose column 2 is 1
        and whose other entries are 0. An empty ``inputs`` yields an empty
        ``(0, states_n)`` tensor.

    Raises:
        IndexError: If ``states_n`` is smaller than 3.
    """
    # Alternative prefix-sum mapping, kept for reference (needs states_n >= 6):
    #   tensor[2] = 1 if x[0] + x[2] == 2 else 0
    #   tensor[3] = 1 if x[1] + x[2] == 2 else 0
    #   tensor[4] = 1 if x[0] + x[3] == 2 else 0
    #   tensor[5] = 1 if x[1] + x[3] == 2 else 0

    # Allocate all rows at once instead of concatenating one tensor per node;
    # this also handles an empty input, on which torch.cat([]) would raise.
    input_states = torch.zeros(len(inputs), states_n)
    input_states[:, 2] = 1
    return input_states

# ToDo: this is just a hack to retrieve multiple state machines -> define which state machine to use
def map_x_to_start_ids(inputs):
    """Choose a start state-machine id (0-3) for every node from its features.

    The feature index pairs (0, 2), (1, 2) and (0, 3) are tested in that
    order; the position of the first pair whose two entries sum to 2 is the
    id. A node matching none of the pairs gets id 3.

    Args:
        inputs: Iterable of per-node feature vectors, indexable up to index 3.

    Returns:
        List with one integer id per node, in input order.
    """
    feature_pairs = ((0, 2), (1, 2), (0, 3))
    fallback_id = len(feature_pairs)

    start_ids = []
    for features in inputs:
        start_id = fallback_id
        # Stop at the first matching pair so later features are only read when needed.
        for candidate, (first, second) in enumerate(feature_pairs):
            if features[first] + features[second] == 2:
                start_id = candidate
                break
        start_ids.append(start_id)

    return start_ids



In [None]:
# PrefixSum dataset requested with num_graphs=4 and num_nodes=4; used for the plots below.
train_dataset = prefixsum.PrefixSum(num_graphs=4, num_nodes=4).data
model.set_iterations(10)
# networkx view of the graph at index 2, used for drawing.
G = to_networkx(train_dataset[2])

# Node fill colour: white when the node's first feature is set, black otherwise.
def get_color(node_tensor):
    """Return the fill colour used when drawing a node.

    Args:
        node_tensor: Feature vector of the node; only entry 0 is read.

    Returns:
        ``"white"`` if entry 0 is truthy, otherwise ``"black"``.
    """
    return "white" if node_tensor[0] else "black"


def get_edge_color(node_tensor):
    """Return the outline colour used when drawing a node.

    Args:
        node_tensor: Feature vector of the node; entry 3 is read as the root flag.

    Returns:
        ``"red"`` if entry 3 is truthy, otherwise ``"black"``.
    """
    return "red" if node_tensor[3] else "black"


# Index of the graph shown in this cell. It must match the graph `G` was built
# from, so that input, ground truth, prediction and difference all describe the
# same sample (the prediction used to be computed on graph 0 but drawn on graph 2).
SAMPLE_INDEX = 2
sample = train_dataset[SAMPLE_INDEX]

color_map = [get_color(tensor) for tensor in sample.x]
edge_color_map = [get_edge_color(tensor) for tensor in sample.x]


print("input picture")
nx.draw_kamada_kawai(G, node_color=color_map, edgecolors=edge_color_map)
plt.show()


print("ground truth")
color_map = ['white' if value == 0 else 'black' for value in sample.y]

nx.draw_kamada_kawai(G, node_color=color_map, edgecolors=edge_color_map)
plt.show()

print("prediction")
# NOTE(review): 6 states are hard-coded here while STATES_N is 3 -- confirm
# which width the loaded checkpoint expects.
out = model(map_x_to_state(sample.x, 6), sample.edge_index)
out = torch.argmax(out, dim=-1)

color_map = ['#34495e' if predicted == 0 else '#e74c3c' for predicted in out]
nx.draw_kamada_kawai(G, node_color=color_map)
plt.show()


print("difference")
# Green where the predicted class equals the label of the same sample, red otherwise.
color_map = ['#2ecc71' if out[i] == sample.y[i] else '#c0392b' for i in range(len(out))]
nx.draw_kamada_kawai(G, node_color=color_map)
plt.show()


In [None]:
from graphviz import Digraph
from itertools import product
import math

# apply hardmax to T: keep only the arg-max entry along the last axis as a one-hot float tensor
T = torch.nn.functional.one_hot(
        torch.argmax(model.T, dim=-1), model.T.shape[-1]
    ).to(torch.float32)

G = Digraph(graph_attr={'rankdir':'LR'}, node_attr={'shape':'circle', "width": "0.8"})

state_names = ['f0', 'f1', 's0', 's1', 's2', 's3', 'r1', 'r2']
num_states = len(state_names)

for state_name in state_names:
    G.node(state_name, penwidth='3px')

# Draw an edge x -> y for every index c with T[c][x][y] == 1. The label is c
# written in binary, zero-padded to one digit per state name.
# NOTE(review): this assumes T has at least 2**num_states entries along axis 0 -- confirm.
for x, y in product(range(num_states), repeat=2):
    for c in range(2**num_states):
        if T[c][x][y] == 1:
            G.edge(state_names[x], state_names[y], f"[{c:0{num_states}b}]")

G.render('task0_ref_fsm', format='svg')

display(G)


### Visualize the state machine of each node in each iteration

In [None]:
from graphviz import Digraph

# Single PrefixSum graph (requested with num_nodes=10) and its networkx view for drawing.
train_dataset = prefixsum.PrefixSum(num_graphs=1, num_nodes=10).data
G = to_networkx(train_dataset[0])

# Node colours: grey where the first feature is 0, blue otherwise.
color_map = ['#bdc3c7' if tensor[0] == 0 else '#2980b9' for tensor in train_dataset[0].x]        

print("input picture")
nx.draw_kamada_kawai(G, node_color=color_map, with_labels=True)
plt.show()


# set model to argmax and enable logging
model.set_hardmax(True)
model.set_logging(True)
model.set_iterations(12)

# With logging enabled the call is unpacked into three values; `all_states` and
# `transitions` are consumed by the per-node drawing further down in this cell.
# NOTE(review): 6 states are hard-coded while STATES_N is 3 -- confirm against the checkpoint.
out, all_states, transitions = model(map_x_to_state(train_dataset[0].x, 6), train_dataset[0].edge_index)
out = torch.argmax(out, dim=-1)


# Green where the predicted class equals the label, red otherwise.
print("difference")
color_map = ['#2ecc71' if out[i] == train_dataset[0].y[i] else '#c0392b' for i in range(len(out))]        
nx.draw_kamada_kawai(G, node_color=color_map)
plt.show()

# A bare top-level `return` used to sit here to skip the drawing below. Outside a
# function that is a SyntaxError, so the whole cell never ran. Set the flag to
# False to skip the per-node drawings instead.
DRAW_NODE_STATE_MACHINES = True

if DRAW_NODE_STATE_MACHINES:
    state_names = ['f0', 'f1', 's0', 's1', 's2', 's3'] #, 'r1']#, 'r2']

    # `all_states` holds one entry per logged step, each indexed by node first,
    # so the first entry gives the number of nodes.
    num_nodes = len(all_states[0])

    # draw state machine for each node
    for i in range(num_nodes):
        print(f"state machine used {i}")

        G = Digraph(graph_attr={'rankdir':'LR'}, node_attr={'shape':'circle', "width": "0.8"})

        # State index of node i at every logged step. `.item()` yields plain ints,
        # so the set below really deduplicates (0-dim tensors hash by identity).
        node_states = [torch.argmax(state[i]).item() for state in all_states]

        for index in sorted(set(node_states)):
            G.node(state_names[index], penwidth='3px')

        # One edge per step, labelled "<step> - <transition values>"; the transition
        # that leads into step k is read from transitions[k - 1].
        for step in range(1, len(node_states)):
            transition = transitions[step - 1][i].tolist()
            G.edge(
                state_names[node_states[step - 1]],
                state_names[node_states[step]],
                f"{step} - {transition}",
            )

        # NOTE(review): every node renders to the same file name, so only the
        # last node's SVG remains on disk; the inline display shows each one.
        G.render('task0_ref_fsm', format='svg')
        display(G)





In [None]:

G = Digraph(graph_attr={'rankdir':'LR'}, node_attr={'shape':'circle', "width": "0.8"})

state_names = ['f0', 'f1', 's0', 's1', 's2', "s3", "r1"]

# Declare the first six states up front; 'r1' is left out on purpose.
for state_name in state_names[:6]:
    G.node(state_name, penwidth='3px')
#G.node(state_names[6], penwidth='3px')

# (source state, target state, transition label) triples already drawn.
added = set()

for n in range(1, 20):

    model.set_iterations(n+2)
    train_dataset = prefixsum.PrefixSum(num_nodes=n).data

    for graph_index in range(2):
        graph = train_dataset[graph_index]

        # set model to argmax and enable logging
        model.set_hardmax(True)
        model.set_logging(True)

        out, all_states, transitions = model(map_x_to_state(graph.x, 6), graph.edge_index)
        out = torch.argmax(out, dim=-1)

        for i in range(n):

            current_state = torch.argmax(all_states[0][i], dim=-1).item()

            # Walk over every logged step. `all_states` is indexed [step][node],
            # so the step count is len(all_states), not len(all_states[i]).
            for step in range(1, len(all_states)):

                next_state = torch.argmax(all_states[step][i], dim=-1).item()
                # The transition leading into `step` for node i, matching the
                # [step - 1][node] indexing used by the other cells.
                transition_value = str(transitions[step - 1][i].tolist())

                transitionTuple = (state_names[current_state], state_names[next_state], transition_value)

                if transitionTuple not in added:

                    # Report any move out of f0/f1 into another state.
                    if state_names[current_state] in ['f0', 'f1'] and state_names[next_state] not in ['f0', 'f1']:
                        print("for n", n, "i", i,  "j", step, "state", state_names[current_state], state_names[next_state])

                    G.edge(state_names[current_state], state_names[next_state], transition_value)
                    added.add(transitionTuple)

                current_state = next_state

G.render('task0_ref_fsm', format='svg')

display(G)




In [None]:


# One diagram per initial state index s; only nodes whose first logged state is s contribute.
for s in range(2, 6):
    G = Digraph(graph_attr={'rankdir':'LR'}, node_attr={'shape':'circle', "width": "0.8"})

    state_names = ['f0', 'f1', 's0', 's1', 's2', "s3", "r1"]

    # Declare the first six states up front; 'r1' is left out.
    for state_name in state_names[:6]:
        G.node(state_name, penwidth='3px')

    # (source state, target state, transition label) triples already drawn.
    added = set()

    for n in range(1, 20):

        model.set_iterations(n+2)
        train_dataset = prefixsum.PrefixSum(num_nodes=n).data

        for graph_index in range(2):
            graph = train_dataset[graph_index]

            # set model to argmax and enable logging
            model.set_hardmax(True)
            model.set_logging(True)

            out, all_states, transitions = model(map_x_to_state(graph.x, 6), graph.edge_index)
            out = torch.argmax(out, dim=-1)

            for i in range(n):

                current_state = torch.argmax(all_states[0][i], dim=-1).item()
                if current_state != s:
                    continue

                # Walk over every logged step. `all_states` is indexed [step][node],
                # so the step count is len(all_states), not len(all_states[i]).
                for step in range(1, len(all_states)):

                    next_state = torch.argmax(all_states[step][i], dim=-1).item()
                    # The transition leading into `step` for node i, matching the
                    # [step - 1][node] indexing used by the other cells.
                    transition_value = str(transitions[step - 1][i].tolist())

                    transitionTuple = (state_names[current_state], state_names[next_state], transition_value)

                    if transitionTuple not in added:

                        # Report any move out of f0/f1 into another state.
                        if state_names[current_state] in ['f0', 'f1'] and state_names[next_state] not in ['f0', 'f1']:
                            print("for n", n, "i", i,  "j", step, "state", state_names[current_state], state_names[next_state])

                        G.edge(state_names[current_state], state_names[next_state], transition_value)
                        added.add(transitionTuple)

                    current_state = next_state

    print("starting state s", s-2)
    # NOTE(review): every s renders to the same file name, so only the last diagram stays on disk.
    G.render('task0_ref_fsm', format='svg')
    display(G)



In [None]:
# TODO: this uses some not so nice code -> we should introduce a separate training pipeline!
# code to visualize multi-state machine model

# `rootvalue` was used below without ever being imported (NameError on a fresh kernel).
# NOTE(review): assumes the module lives next to `prefixsum` in dataset_generator -- confirm.
from dataset_generator import rootvalue

# One diagram per start id; only nodes assigned that id by map_x_to_start_ids contribute.
for start_id in range(0,3):
    G = Digraph(graph_attr={'rankdir':'LR'}, node_attr={'shape':'circle', "width": "0.8"})

    state_names = ['f0', 'f1', 's0']

    for state_name in state_names:
        G.node(state_name, penwidth='3px')
    #G.node(state_names[3], penwidth='3px')

    # (source state, target state, transition label) triples already drawn.
    added = set()

    for n in range(3, 10):

        model.set_iterations(n+2)
        train_dataset = rootvalue.RootValue(num_nodes=n).data

        for graph_index in range(len(train_dataset)):
            graph = train_dataset[graph_index]

            start_state_ids = map_x_to_start_ids(graph.x)

            # set model to argmax and enable logging
            model.set_hardmax(True)
            model.set_logging(True)

            out, all_states, transitions = model(map_x_to_state(graph.x, STATES_N), graph.edge_index, start_state_ids)
            out = torch.argmax(out, dim=-1)

            for i in range(n):

                current_state = torch.argmax(all_states[0][i], dim=-1).item()
                if start_id != start_state_ids[i]:
                    continue

                # Walk over every logged step. `all_states` is indexed [step][node],
                # so the step count is len(all_states), not len(all_states[i]).
                for step in range(1, len(all_states)):

                    next_state = torch.argmax(all_states[step][i], dim=-1).item()
                    transition_value = str(transitions[step - 1][i].tolist())

                    transitionTuple = (state_names[current_state], state_names[next_state], transition_value)

                    if transitionTuple not in added:

                        G.edge(state_names[current_state], state_names[next_state], transition_value)
                        added.add(transitionTuple)

                    current_state = next_state

    print("starting state s", start_id)
    # NOTE(review): every start id renders to the same file name, so only the last diagram stays on disk.
    G.render('task0_ref_fsm', format='svg')
    display(G)



In [None]:
# TODO: this uses some not so nice code -> we should introduce a separate training pipeline!
# code to visualize multi-state machine model

# One diagram per start id; only nodes assigned that id by map_x_to_start_ids contribute.
for start_id in range(0,4):
    G = Digraph(graph_attr={'rankdir':'LR'}, node_attr={'shape':'circle', "width": "0.8"})

    state_names = ['f0', 'f1', 's0']

    for state_name in state_names:
        G.node(state_name, penwidth='3px')
    #G.node(state_names[3], penwidth='3px')

    # (source state, target state, transition label) triples already drawn.
    added = set()

    for n in range(3, 10):

        model.set_iterations(n+2)
        train_dataset = prefixsum.PrefixSum(num_nodes=n).data

        for graph_index in range(len(train_dataset)):
            graph = train_dataset[graph_index]

            start_state_ids = map_x_to_start_ids(graph.x)

            # set model to argmax and enable logging
            model.set_hardmax(True)
            model.set_logging(True)

            out, all_states, transitions = model(map_x_to_state(graph.x, STATES_N), graph.edge_index, start_state_ids)
            out = torch.argmax(out, dim=-1)

            for i in range(n):

                current_state = torch.argmax(all_states[0][i], dim=-1).item()
                if start_id != start_state_ids[i]:
                    continue

                # Walk over every logged step. `all_states` is indexed [step][node],
                # so the step count is len(all_states), not len(all_states[i]).
                for step in range(1, len(all_states)):

                    next_state = torch.argmax(all_states[step][i], dim=-1).item()
                    transition_value = str(transitions[step - 1][i].tolist())

                    transitionTuple = (state_names[current_state], state_names[next_state], transition_value)

                    if transitionTuple not in added:

                        G.edge(state_names[current_state], state_names[next_state], transition_value)
                        added.add(transitionTuple)

                    current_state = next_state

    print("starting state s", start_id)
    # NOTE(review): every start id renders to the same file name, so only the last diagram stays on disk.
    G.render('task0_ref_fsm', format='svg')
    display(G)

