
Copyright (c) 2022 Imtiaz Karim & Abdullah Al Ishtiaq

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


In [None]:
import os
import sys
from collections import defaultdict
from collections import deque

In [None]:
# Input DOT files, one per state machine. They are parsed by
# FSMGraph.dot2graph with the node prefixes "ll_", "smp_" and "recon_".
LL_DOT_FILENAME = "input/anonymous_device_ll.dot"
SMP_DOT_FILENAME = "input/anonymous_device_smp.dot"
RECON_DOT_FILENAME = "input/anonymous_device_re.dot"

# Destination file for the merged graph written by FSMGraph.graph2dot.
OUTPUT_FILENAME = "anonymous_device_merged.dot"


In [None]:
# Edge label (input / output) that locates the state in the LL graph where
# the SMP graph is attached (used by ll_graph.add_fsm).
LL_FINAL_IN_SYMBOL = "pair_req"
LL_FINAL_OUT_SYMBOL = "pair_resp"
# Edge label (input / output) that locates the state in the SMP graph where
# the reconnection graph is attached (used by smp_graph.add_fsm).
SMP_FINAL_IN_SYMBOL = "dh_check"
SMP_FINAL_OUT_SYMBOL = "dh_key_response"

In [None]:
# Renames input symbols read from the DOT edge labels to the names used in
# the merged graph. FSMGraph.dot2graph applies it to input symbols only;
# output symbols are kept unchanged.
replace_dict = {
    "feature_req_none": "feature_none_req",
    "mtu_req_llid_zero": "mtu_llid_zero_req",
    "mtu_req_mtu_zero": "mtu_mtu_zero_req",
    "length_req_time_zero": "length_time_zero_req",
    "length_req_rx_tx_zero": "length_rx_tx_zero_req",
    "con_req_interval_zero": "con_interval_zero_req",
    "con_req_crc_zero": "con_crc_zero_req",
    "con_req_length_zero": "con_length_zero_req",
    "con_req_channel_map_zero": "con_channel_map_zero_req",
    "con_req_timeout_zero": "con_timeout_zero_req",
    "version_req_llid_zero": "version_llid_zero_req",
    "version_req_max_len": "version_max_len_req",
    "pair_req_oob": "pair_oob_req",
    "pair_req_keyboard_display": "pair_keyboard_display_req",
    "pair_req_display_yes_no": "pair_display_yes_no_req",
    "pair_req_no_sc": "pair_no_sc_req",
    "pair_req_no_sc_keyboard_display": "pair_keyboard_display_no_sc_req",
    "pair_req_no_sc_display_yes_no": "pair_display_yes_no_no_sc_req",
    "pair_req_key_zero": "pair_key_zero_req",
    "key_exchange_invalid": "key_invalid_exchange",
    "dh_check_invalid": "dh_invalid_check",
    "pair_confirm_wrong_value": "pair_wrong_value_confirm",
    "enc_pause_resp_plain": "enc_plain_pause_resp",
    "enc_pause_req_plain": "enc_plain_pause_req",
    "start_enc_resp_plain": "start_enc_plain_resp"
}

In [None]:
class FSMGraph:
    """Directed graph whose edges carry an input symbol and an output symbol.

    ``self.edges`` maps a source state to a list of
    ``(target_state, in_symbol, out_symbol)`` tuples. Only states that have
    an entry in ``self.edges`` are reported by :meth:`get_nodes`. States that
    appear solely as edge targets are registered (with an empty edge list)
    the first time they are indexed, because ``edges`` is a ``defaultdict``.
    """

    def __init__(self):
        self.edges = defaultdict(list)  # state -> [(target, in_symbol, out_symbol), ...]
        self.root = None                # initial state; required by the BFS helpers

    def set_root(self, root):
        """Set the initial state."""
        self.root = root

    def get_root(self):
        """Return the initial state (``None`` until set)."""
        return self.root

    def add_edge(self, u, v, in_symbol, out_symbol):
        """Append the transition ``u --in_symbol / out_symbol--> v``."""
        self.edges[u].append((v, in_symbol, out_symbol))

    def get_edges(self):
        """Return the live adjacency mapping (not a copy)."""
        return self.edges

    def get_edges_at(self, node):
        """Return the outgoing edges of ``node`` (registers it if absent)."""
        return self.get_edges()[node]

    def get_nodes(self):
        """Return the states that have an entry in the adjacency mapping."""
        return list(self.get_edges().keys())

    def get_num_nodes(self):
        """Return the number of states in the adjacency mapping."""
        return len(self.get_nodes())

    def print(self):
        """Return the adjacency mapping (despite the name, prints nothing)."""
        return self.get_edges()

    def dot2graph(self, dot_filename, node_prefix):
        """Load transitions from a DOT file, prefixing every state name.

        Edge lines are expected to look like ``s0 -> s1 [label="in / out"];``.
        The line ``__start0 -> sN;`` designates the root. Input symbols are
        renamed through the module-level ``replace_dict``.

        Args:
            dot_filename: Path of the DOT file to read.
            node_prefix: String prepended to every state name so that graphs
                loaded from different files do not share state names.
        """
        with open(dot_filename, 'r', encoding="utf-8") as dot_file:
            dot_file_lines = dot_file.readlines()

        for line in dot_file_lines:
            if "->" not in line:
                continue
            line_splits = line.strip().split()
            node1 = node_prefix + line_splits[0]
            node2 = node_prefix + line_splits[2]

            if "__start0" in node1:
                # The start pseudo-edge only designates the root; it is not a transition.
                self.set_root(node2.replace(";", ""))
                continue

            in_symbol = line_splits[3].replace("[label=\"", "").strip()
            in_symbol = replace_dict.get(in_symbol, in_symbol)
            out_symbol = line_splits[5].replace("\"];", "").strip()

            self.add_edge(node1, node2, in_symbol, out_symbol)

    def BFS_distance(self):
        """Return ``{state: number of edges from root}`` for every reachable state.

        Returns ``None`` (after printing a message) if no root has been set.
        """
        if self.root is None:
            print("Graph not initialized properly")
            return None

        all_edges = self.get_edges()
        distance_dict = {self.root: 0}  # also serves as the visited set
        queue = deque([self.root])

        while queue:
            s = queue.popleft()
            for v, _, _ in all_edges[s]:
                # Membership test on distance_dict (not a pre-seeded visited
                # map) so target-only/sink states do not raise KeyError.
                if v not in distance_dict:
                    distance_dict[v] = distance_dict[s] + 1
                    queue.append(v)

        return distance_dict

    def find_target_states(self, target_in_symbol: str, target_out_symbol: str) -> set:
        """Return the source states of edges labelled ``target_in_symbol / target_out_symbol``."""
        return {
            node
            for node, edges in self.get_edges().items()
            for _, in_symbol, out_symbol in edges
            if in_symbol == target_in_symbol and out_symbol == target_out_symbol
        }

    def find_furthest_target_state(self, target_in_symbol: str, target_out_symbol: str):
        """Return the matching source state with the greatest BFS distance from root.

        Candidate states come from :meth:`find_target_states`. Candidates that
        are unreachable from the root are ignored. Returns ``""`` when there
        is no reachable candidate or the root is unset.
        """
        target_states = self.find_target_states(target_in_symbol, target_out_symbol)
        distances = self.BFS_distance()

        print("Found states :", target_states)
        print("Found distances :", distances)

        if not target_states or distances is None:
            return ""

        max_node = ""
        max_dist = -1
        for node in target_states:
            dist = distances.get(node)
            if dist is None:
                continue  # unreachable from the root
            if dist > max_dist:
                max_node, max_dist = node, dist

        return max_node

    def BFS_plain(self):
        """Return the set of states reachable from the root (``None`` if unset)."""
        if self.root is None:
            print("Graph not initialized properly")
            return None

        all_edges = self.get_edges()
        visited = {self.root}
        queue = deque([self.root])

        while queue:
            s = queue.popleft()
            for k, _, _ in all_edges[s]:
                if k not in visited:
                    visited.add(k)
                    queue.append(k)

        return visited

    def find_unreachable_nodes(self):
        """Return the keyed states that cannot be reached from the root."""
        all_nodes = set(self.get_nodes())
        reachable_nodes = self.BFS_plain()

        return all_nodes.difference(reachable_nodes)

    def remove_node(self, target_node):
        """Delete ``target_node`` and every edge that points at it (no-op if absent)."""
        self.edges.pop(target_node, None)

        for node1, edges in list(self.edges.items()):
            self.edges[node1] = [edge for edge in edges if edge[0] != target_node]

        print("Done removing node", target_node)

    def remove_in_symbol(self, target_in_symbol):
        """Drop every edge whose input is ``target_in_symbol``, then prune unreachable states."""
        for node1, edges in list(self.edges.items()):
            self.edges[node1] = [edge for edge in edges if edge[1] != target_in_symbol]

        unreachable_nodes = self.find_unreachable_nodes()
        while unreachable_nodes:
            for u_node in unreachable_nodes:
                self.remove_node(u_node)
            unreachable_nodes = self.find_unreachable_nodes()

        print("Done removing symbol", target_in_symbol)

    def add_fsm(self, final_in_symbol: str, final_out_symbol: str, next_fsm: "FSMGraph",
                transition_in_symbol="", transition_out_symbol=""):
        """Merge ``next_fsm`` into this graph at a "final" state.

        The final state is the source of a ``final_in_symbol / final_out_symbol``
        edge that lies furthest (in BFS hops) from this graph's root.

        If either transition symbol is empty, ``next_fsm`` is spliced in:
        1. Every edge with input ``final_in_symbol`` is removed from this
           graph, and states left unreachable are pruned.
        2. The final state takes over ``next_fsm``'s root edges.
        3. All edges into that root are redirected to the final state.

        Otherwise, an edge ``final --transition_in / transition_out--> root``
        is added.

        ``next_fsm`` itself is never modified; its edges are copied. The
        merge is aborted, with a printed message, if no final state is found
        or if the two graphs share state names.
        """
        final_state = self.find_furthest_target_state(final_in_symbol, final_out_symbol)
        if final_state == "":
            print("Cannot find target state")
            return

        final_states = [final_state]

        next_fsm_root = next_fsm.get_root()
        # Private copy: neither delete from next_fsm nor share its lists,
        # otherwise later edits to this graph would corrupt next_fsm.
        next_fsm_edges = {state: list(edges) for state, edges in next_fsm.get_edges().items()}

        if set(self.get_nodes()) & set(next_fsm_edges):
            print("Change state names. Initialize with different prefixes...")
            print("Merge failed...")
            return

        if transition_in_symbol == "" or transition_out_symbol == "":
            # Splice mode: copy next fsm root state behavior to final states.
            self.remove_in_symbol(final_in_symbol)          # avoid conflict
            root_edges = next_fsm_edges.pop(next_fsm_root, [])
            for f_state in final_states:
                self.edges[f_state].extend(root_edges)

            self.edges.update(next_fsm_edges)

            # The root itself was not copied; point its incoming edges at the final state.
            for state in list(self.edges):
                self.edges[state] = [
                    (final_states[0] if v == next_fsm_root else v, in_symbol, out_symbol)
                    for v, in_symbol, out_symbol in self.edges[state]
                ]
        else:
            # Transition mode: add an explicit edge into next fsm's root state.
            for f_state in final_states:
                self.edges[f_state].append((next_fsm_root, transition_in_symbol, transition_out_symbol))
            self.edges.update(next_fsm_edges)

        print("FSM Merged -> final_in_symbol :", final_in_symbol, ", final_out_symbol :", final_out_symbol, ", final states :", final_states)

    def graph2dot(self, output_filename):
        """Write the graph to ``output_filename`` in DOT format.

        Every keyed state is declared as a circle. Edges are written as
        ``u -> v [label="in / out"];``, and ``__start0`` points at the root.
        """
        all_nodes = self.get_nodes()
        all_edges = self.get_edges()

        with open(output_filename, 'w', encoding="utf-8") as out_file:
            out_file.write("digraph g {\n")
            out_file.write("__start0 [label=\"\" shape=\"none\"];\n")
            out_file.write("\n")

            for node in all_nodes:
                node = node.strip()
                if node == "":
                    continue
                out_file.write("\t{} [shape=\"circle\" label=\"{}\"];\n".format(node, node))
            out_file.write("\n")

            for node in all_nodes:
                for v, in_symbol, out_symbol in all_edges[node]:
                    out_file.write("\t{} -> {} [label=\"{} / {}\"];\n".format(node, v, in_symbol, out_symbol))
                out_file.write("\n")

            out_file.write("__start0 -> {};\n".format(self.get_root()))
            out_file.write("}\n")

        print("Graph generated at :", output_filename)
        

In [None]:
# Load the LL machine; the "ll_" prefix keeps its state names distinct
# from the other graphs (add_fsm refuses to merge overlapping names).
ll_graph = FSMGraph()
ll_graph.dot2graph(dot_filename=LL_DOT_FILENAME, node_prefix="ll_")




In [None]:
# Load the SMP machine with its own "smp_" state-name prefix.
smp_graph = FSMGraph()
smp_graph.dot2graph(dot_filename=SMP_DOT_FILENAME, node_prefix="smp_")


In [None]:
# Load the reconnection machine with its own "recon_" state-name prefix.
recon_graph = FSMGraph()
recon_graph.dot2graph(dot_filename=RECON_DOT_FILENAME, node_prefix="recon_")


In [None]:
# Attach the reconnection machine to the SMP machine through an explicit
# "discon_req / null_action" transition (transition mode of add_fsm).
smp_graph.add_fsm(
    final_in_symbol=SMP_FINAL_IN_SYMBOL,
    final_out_symbol=SMP_FINAL_OUT_SYMBOL,
    next_fsm=recon_graph,
    transition_in_symbol="discon_req",
    transition_out_symbol="null_action",
)


In [None]:
# Splice the SMP machine into the LL machine. No transition symbols are
# given, so add_fsm runs in splice mode: the final LL state takes over the
# SMP root's outgoing edges.
ll_graph.add_fsm(
    final_in_symbol=LL_FINAL_IN_SYMBOL,
    final_out_symbol=LL_FINAL_OUT_SYMBOL,
    next_fsm=smp_graph,
)


In [None]:
# Write the merged machine as a DOT file.
ll_graph.graph2dot(output_filename=OUTPUT_FILENAME)