In [None]:
import graphviz
import json
import sqlite3
import sys, os, argparse
import math
import time
import collections

In [None]:
def run_sql_query(db, query, params=()):
    """Run one SQL statement against a SQLite database file and return its rows.

    Args:
        db: Path of the SQLite database file to open.
        query: SQL text to execute.
        params: Optional sequence (or mapping) of values bound to ``?`` / named
            placeholders in ``query``. Defaults to no parameters, so existing
            two-argument calls behave as before.

    Returns:
        A list of rows, each row itself a list of column values. The rows are
        passed through a JSON round trip, so tuples become lists and only
        JSON-serialisable column values are supported.

    Raises:
        sqlite3.Error: If the database cannot be opened or the statement fails.
        TypeError: If a column value cannot be serialised to JSON.

    Note:
        Nothing is committed; the connection is closed as soon as the rows
        have been fetched.
    """
    conn = sqlite3.connect(db)
    try:
        rows = conn.execute(query, params).fetchall()
    finally:
        # Always release the handle: this helper is called in tight loops and
        # an unclosed connection per call would otherwise pile up.
        conn.close()
    # Round trip through JSON to hand back plain nested lists.
    return json.loads(json.dumps(rows))

In [None]:
def gen_records(graph, nodes, edge_dict_list):
    """Apply the shared styling to ``graph`` and add the given nodes and edges.

    Args:
        graph: Graph object exposing ``attr``, ``node``, ``edge`` and the
            ``node_attr`` / ``edge_attr`` / ``graph_attr`` mappings.
        nodes: Iterable of dicts with the keys ``id``, ``txt``, ``fillcolor``
            and ``rank`` (see ``gen_node``).
        edge_dict_list: Iterable of dicts with the keys ``src``, ``dest``,
            ``label`` and optionally ``penwidth``; a missing key is passed on
            as ``None``.
    """
    graph.attr('node', shape='record', fontname='Courier', size='6,6')

    # (target mapping, attribute, value) applied in this exact order.
    default_attrs = (
        (graph.node_attr, 'fontname', "Courier"),
        (graph.node_attr, 'color', "lightgreen"),
        (graph.node_attr, 'fillcolor', "lightblue"),
        (graph.node_attr, 'style', 'filled'),
        (graph.edge_attr, 'fontname', "Courier"),
        (graph.graph_attr, 'rankdir', "TB"),
        (graph.node_attr, 'fontsize', "10"),
    )
    for attr_map, attr_name, attr_value in default_attrs:
        attr_map[attr_name] = attr_value

    for entry in nodes:
        node_id, node_text = entry.get("id"), entry.get("txt")
        graph.node(node_id, node_text,
                   fillcolor=entry.get("fillcolor"),
                   rank=entry.get("rank"))

    for link in edge_dict_list:
        tail, head = link.get("src"), link.get("dest")
        graph.edge(tail, head,
                   label=link.get("label"),
                   penwidth=link.get("penwidth"))

def gen_node(node_id, node_txt, fillcolor, rank):
    """Bundle one node's properties into the dict shape ``gen_records`` reads.

    Returns:
        A new dict with the keys ``id``, ``txt``, ``fillcolor`` and ``rank``.
    """
    return {
        "id": node_id,
        "txt": node_txt,
        "fillcolor": fillcolor,
        "rank": rank,
    }


In [None]:
def _total_num_of_caps(db):
    """Return the result of counting all rows in ``cap_info``.

    The value is whatever ``run_sql_query`` returns for the COUNT query,
    i.e. the raw result rows rather than a bare integer.
    """
    return run_sql_query(db, "SELECT COUNT(*) FROM cap_info")


In [None]:
def gen_full_graph(db, graph):
    """Build the per-mapping capability graph and emit it into ``graph``.

    One node is created for every distinct ``mmap_path`` in the ``vm`` table.
    For each ``cap_info`` row whose address (column 2) lies inside one of that
    mapping's ``[start_addr, end_addr]`` ranges, and whose own path (column 1)
    is a different mapping, an edge is drawn from the capability's path to the
    mapping, labelled with the capability's permissions (column 3).

    Repeated (source, destination, label) edges are merged into one edge whose
    pen width ``w`` becomes ``log10(10**w + 1)`` on every repeat, starting at 1.

    Args:
        db: Path of the SQLite database to read.
        graph: Graph object handed to ``gen_records`` once all nodes and edges
            have been collected.

    Raises:
        IndexError: If a start address has to be looked up for a special path
            ("unknown", "Stack", "Guard") that has no matching ``vm`` row.
    """
    special_paths = ("unknown", "Stack", "Guard")

    def sql_text(value):
        # Double embedded single quotes so the value cannot terminate the
        # surrounding '...' SQL literal (e.g. a path containing an apostrophe).
        return str(value).replace("'", "''")

    mmap_rows = run_sql_query(db, "SELECT DISTINCT mmap_path FROM vm")
    caps = run_sql_query(db, "SELECT * FROM cap_info")

    nodes = []
    # Keyed by (src, dest, label). Popping and re-inserting a key moves it to
    # the end, which mirrors removing and re-appending an edge in a list.
    edges = {}
    # Labels for special capability paths, looked up once per distinct path.
    special_labels = {}

    for mmap_row in mmap_rows:
        mmap_path = mmap_row[0]

        # Fetch both ends of every range in one query so each start address is
        # paired with the end address of the same row.
        addr_ranges = run_sql_query(
            db,
            "SELECT start_addr, end_addr FROM vm WHERE mmap_path LIKE '%"
            + sql_text(mmap_path) + "%'")

        if mmap_path in special_paths:
            # These names are not unique identifiers on their own, so the
            # first start address is appended to the label.
            path_label = mmap_path + " (" + addr_ranges[0][0] + ")"
            fillcolor = "lightgrey"
            rank = "source"
        else:
            path_label = mmap_path
            fillcolor = "lightblue"
            rank = "max"

        nodes.append(gen_node(path_label, path_label, fillcolor, rank))

        for cap in caps:
            cap_path = cap[1]
            cap_addr = cap[2]
            cap_perms = cap[3]

            for start_addr, end_addr in addr_ranges:
                # NOTE(review): addresses are compared exactly as stored; if
                # they are text this is a lexicographic comparison -- confirm
                # the stored format makes that equivalent to numeric order.
                if not (start_addr <= cap_addr <= end_addr):
                    continue

                # Skip capabilities that live in this same mapping. Both paths
                # are also compared with their last six characters removed, to
                # ignore an appended section tag such as (.bss)/(.got)/(.plt).
                if (cap_path == mmap_path or
                        cap_path == mmap_path[:-6] or
                        cap_path[:-6] == mmap_path or
                        cap_path[:-6] == mmap_path[:-6]):
                    continue

                if cap_path in special_paths:
                    if cap_path not in special_labels:
                        # cap_path is one of the fixed names above, so it is
                        # safe to embed; only the first row is used.
                        first_start = run_sql_query(
                            db,
                            "SELECT start_addr FROM vm WHERE mmap_path='"
                            + cap_path + "'")
                        special_labels[cap_path] = (
                            cap_path + " (" + first_start[0][0] + ")")
                    cap_path_label = special_labels[cap_path]
                else:
                    cap_path_label = cap_path

                edge_key = (cap_path_label, path_label, cap_perms)
                previous = edges.pop(edge_key, None)
                if previous is None:
                    penwidth_weight = 1
                else:
                    # Grow the width logarithmically with the repeat count.
                    penwidth_weight = math.log(
                        (10 ** float(previous["penwidth"])) + 1, 10)
                edges[edge_key] = {
                    "src": cap_path_label,
                    "dest": path_label,
                    "label": cap_perms,
                    "penwidth": str(penwidth_weight),
                }

    gen_records(graph, nodes, list(edges.values()))
    

In [None]:
def show_comparts(db, graph):
    """Build the compartment-level capability graph and emit it into ``graph``.

    One node is created per distinct ``compart_id`` in the ``vm`` table,
    labelled with the id and the basenames of every path mapped in it. For each
    ``cap_info`` row whose address (column 2) lies inside one of a mapping's
    ``[start_addr, end_addr]`` ranges, and whose own path (column 1) is a
    different mapping, an edge is drawn from the compartment of the
    capability's path to the compartment of the mapping, labelled with the
    capability's permissions (column 3). Edges within one compartment are
    not drawn.

    Repeated (source, destination, label) edges are merged into one edge whose
    pen width ``w`` becomes ``log10(10**w + 1)`` on every repeat, starting at 1.

    Args:
        db: Path of the SQLite database to read.
        graph: Graph object handed to ``gen_records`` once all nodes and edges
            have been collected.

    Raises:
        IndexError: If no ``vm`` row matches the path of a capability whose
            compartment has to be looked up.
    """
    def sql_text(value):
        # Double embedded single quotes so the value cannot terminate the
        # surrounding '...' SQL literal (e.g. a path containing an apostrophe).
        return str(value).replace("'", "''")

    compart_rows = run_sql_query(db, "SELECT DISTINCT compart_id FROM vm")
    caps = run_sql_query(db, "SELECT * FROM cap_info")

    nodes = []
    # Keyed by (src, dest, label). Popping and re-inserting a key moves it to
    # the end, which mirrors removing and re-appending an edge in a list.
    edges = {}
    # Compartment id per capability path, looked up once per distinct path.
    compart_of_cap_path = {}

    for compart_row in compart_rows:
        single_compart_id = compart_row[0]
        paths = run_sql_query(
            db,
            "SELECT DISTINCT mmap_path FROM vm WHERE compart_id="
            + str(single_compart_id))

        joined_names = ", ".join(
            os.path.basename(path_row[0]) for path_row in paths)

        if single_compart_id < 0:
            fillcolor = "lightgrey"
            rank = "source"
        else:
            fillcolor = "lightblue"
            rank = "max"
        dest_id = str(single_compart_id)
        nodes.append(gen_node(
            dest_id,
            dest_id + ": " + joined_names,
            fillcolor,
            rank))

        for path_row in paths:
            mmap_path = path_row[0]

            # Progress output: compartment id and mapping basename.
            print(dest_id + ":" + os.path.basename(mmap_path))

            # Fetch both ends of every range in one query so each start
            # address is paired with the end address of the same row.
            addr_ranges = run_sql_query(
                db,
                "SELECT start_addr, end_addr FROM vm WHERE mmap_path LIKE '%"
                + sql_text(mmap_path) + "%'")

            for cap in caps:
                cap_path = cap[1]
                cap_addr = cap[2]
                cap_perms = cap[3]

                for start_addr, end_addr in addr_ranges:
                    # NOTE(review): addresses are compared exactly as stored;
                    # if they are text this is a lexicographic comparison --
                    # confirm the stored format makes that valid.
                    if not (start_addr <= cap_addr <= end_addr):
                        continue

                    # Skip capabilities that live in this same mapping. Both
                    # paths are also compared with their last six characters
                    # removed, to ignore an appended tag such as (.got)/(.plt).
                    if (cap_path == mmap_path or
                            cap_path == mmap_path[:-6] or
                            cap_path[:-6] == mmap_path or
                            cap_path[:-6] == mmap_path[:-6]):
                        continue

                    if cap_path not in compart_of_cap_path:
                        found = run_sql_query(
                            db,
                            "SELECT DISTINCT compart_id FROM vm "
                            "WHERE mmap_path LIKE '%"
                            + sql_text(cap_path) + "%'")
                        # Only the first id is used; all matches are expected
                        # to share one compartment.
                        compart_of_cap_path[cap_path] = found[0][0]
                    src_id = str(compart_of_cap_path[cap_path])

                    if src_id == dest_id:
                        # Capability stays within one compartment: no edge.
                        continue

                    edge_key = (src_id, dest_id, cap_perms)
                    previous = edges.pop(edge_key, None)
                    if previous is None:
                        penwidth_weight = 1
                    else:
                        # Grow the width logarithmically with the repeat count.
                        penwidth_weight = math.log(
                            (10 ** float(previous["penwidth"])) + 1, 10)
                    edges[edge_key] = {
                        "src": src_id,
                        "dest": dest_id,
                        "label": cap_perms,
                        "penwidth": str(penwidth_weight),
                    }

    gen_records(graph, nodes, list(edges.values()))
# SELECT compart_id, mmap_path FROM vm INNER JOIN cap_info ON vm.mmap_path = cap_info.cap_loc_path WHERE cap_info.cap_addr < vm.end_addr AND cap_info.cap_addr > vm.start_addr

In [None]:
def show_comparts_has_caps(db, graph):
    """Build the unweighted "has capability" compartment graph into ``graph``.

    A node with the compartment id as its identifier is added for every
    distinct ``(compart_id, mmap_path)`` pair in the ``vm`` table. For each
    ``cap_info`` row whose address (column 2) lies inside one of the mapping's
    ``[start_addr, end_addr]`` ranges, and whose own path (column 1) differs
    from the mapping name, a single unlabelled edge is drawn from the
    compartment of the capability's path to the mapping's compartment. Edges
    within one compartment are not drawn, and each (source, destination) pair
    appears at most once.

    Args:
        db: Path of the SQLite database to read.
        graph: Graph object handed to ``gen_records`` once all nodes and edges
            have been collected.

    Raises:
        IndexError: If no ``vm`` row matches the path of a capability whose
            compartment has to be looked up.
    """
    def sql_text(value):
        # Double embedded single quotes so the value cannot terminate the
        # surrounding '...' SQL literal (e.g. a path containing an apostrophe).
        return str(value).replace("'", "''")

    # Each row holds two elements: compart_id and mmap_path.
    compart_rows = run_sql_query(
        db, "SELECT distinct compart_id, mmap_path FROM vm")
    caps = run_sql_query(db, "SELECT * FROM cap_info")

    nodes = []
    # Keyed by (src, dest). Popping and re-inserting a key moves it to the
    # end, which mirrors removing and re-appending an edge in a list.
    edges = {}
    # Compartment id per capability path, looked up once per distinct path.
    compart_of_cap_path = {}

    for compart_id, full_path in compart_rows:
        # NOTE(review): from here on the basename, not the full path, is used
        # for both the LIKE match and the same-mapping comparison -- confirm
        # this is intended.
        mmap_path = os.path.basename(full_path)

        # Progress output: compartment id and mapping basename.
        print(str(compart_id) + ":" + mmap_path)

        if compart_id < 0:
            fillcolor = "lightgrey"
            rank = "source"
        else:
            fillcolor = "lightblue"
            rank = "max"

        dest_id = str(compart_id)
        nodes.append(gen_node(
            dest_id,
            dest_id + ": " + mmap_path,
            fillcolor,
            rank))

        # Fetch both ends of every range in one query so each start address is
        # paired with the end address of the same row.
        addr_ranges = run_sql_query(
            db,
            "SELECT start_addr, end_addr FROM vm WHERE mmap_path LIKE '%"
            + sql_text(mmap_path) + "%'")

        for cap in caps:
            cap_path = cap[1]
            cap_addr = cap[2]

            for start_addr, end_addr in addr_ranges:
                # NOTE(review): addresses are compared exactly as stored; if
                # they are text this is a lexicographic comparison -- confirm
                # the stored format makes that valid.
                if not (start_addr <= cap_addr <= end_addr):
                    continue

                # Skip capabilities whose path equals the mapping name. Both
                # are also compared with their last six characters removed, to
                # ignore an appended tag such as (.got)/(.plt).
                if (cap_path == mmap_path or
                        cap_path == mmap_path[:-6] or
                        cap_path[:-6] == mmap_path or
                        cap_path[:-6] == mmap_path[:-6]):
                    continue

                if cap_path not in compart_of_cap_path:
                    found = run_sql_query(
                        db,
                        "SELECT DISTINCT compart_id FROM vm "
                        "WHERE mmap_path LIKE '%"
                        + sql_text(cap_path) + "%'")
                    # Only the first id is used; all matches are expected to
                    # share one compartment.
                    compart_of_cap_path[cap_path] = found[0][0]
                src_id = str(compart_of_cap_path[cap_path])

                if src_id == dest_id:
                    # Capability stays within one compartment: no edge.
                    continue

                edge_key = (src_id, dest_id)
                edges.pop(edge_key, None)
                edges[edge_key] = {"src": src_id, "dest": dest_id, "label": ""}

    gen_records(graph, nodes, list(edges.values()))

In [None]:
# Optional file holding the command line as comma-separated tokens.
CONFIG_FILE = 'chericat_cli.config'

def main(argv):
    """Parse the tool's options and generate the requested output.

    Options are not taken from ``argv``. They are read from ``CONFIG_FILE``
    when that file exists (comma-separated tokens, the first being the program
    name), otherwise a built-in default command line is used.

    Args:
        argv: Unused; accepted so existing ``main(sys.argv)`` calls keep
            working.

    Returns:
        None. Graphs are rendered into the ``graph-output`` directory and
        opened in a viewer; ``-r`` prints the query result.
    """
    dbpath = "../../chericat_dbs/"
    dbname = "test_nginx.db"
    if os.path.isfile(CONFIG_FILE):
        with open(CONFIG_FILE) as f:
            # Strip each token so a trailing newline or spaces after commas do
            # not end up inside an argument.
            tokens = [token.strip() for token in f.read().split(',')]
    else:
        tokens = ['chericat_cli', '-d', dbpath + dbname, '-comp']
    # The first token is the program name; blank tokens carry no argument.
    cli_args = [token for token in tokens[1:] if token]

    parser = argparse.ArgumentParser(prog='chericat_cli')
    parser.add_argument(
        '-d',
        help='The database to use for the queries',
        required=True,
    )

    parser.add_argument(
        '-g',
        help='Generate full capability relationship in mmap graph',
        action='store_true',
    )

    parser.add_argument(
        '-r',
        help='Executes the SQL query on the provided db',
        nargs=1,
    )

    parser.add_argument(
        '-comp',
        help="Show compartments graph",
        action='store_true',
    )

    parser.add_argument(
        '-compc',
        help="Show compartments (has_capability relationship) graph",
        action='store_true',
    )

    # Parse the local list instead of overwriting the process-wide sys.argv.
    args = parser.parse_args(cli_args)

    db = args.d
    # Name output files after the database actually in use, not the default.
    graph_basename = os.path.basename(db)

    def render_timed(suffix, build_graph, description):
        # Build one graph, report how long the build took, then render it.
        start = time.perf_counter()
        digraph = graphviz.Digraph('G', filename=graph_basename + suffix)
        build_graph(db, digraph)
        end = time.perf_counter()
        print(description + " generation time taken: " + str(end - start) + "s")
        digraph.render(directory='graph-output', view=True)

    if args.g:
        render_timed('.libview_graph.gv', gen_full_graph, "Full graph")

    if args.r:
        print(run_sql_query(db, args.r[0]))

    if args.comp:
        render_timed('.comp_graph.gv', show_comparts, "Comparts graph")

    if args.compc:
        render_timed('.comp_no_perms_graph.gv', show_comparts_has_caps,
                     "Capability comparts graph")

    return None

if __name__ == '__main__':
    main(sys.argv)
