In [1]:
import graphviz
import json
import sqlite3
import sys, os, argparse
import math
import time

In [2]:
def run_sql_query(db, query):
    conn = sqlite3.connect(db)
    cur = conn.cursor()
    cur.execute(query)
    result_data = json.dumps(cur.fetchall())
    result_json = json.loads(result_data)
    return result_json

In [3]:
def gen_records(graph, nodes, edge_dict_list):
    graph.attr('node', shape='record', fontname='Courier', size='6,6')
    graph.node_attr['fontname'] = "Courier"
    graph.node_attr['color'] = "lightgreen"
    graph.node_attr['fillcolor'] = "lightblue"
    graph.node_attr['style'] = 'filled'
    graph.edge_attr['fontname'] = "Courier"
    graph.graph_attr['rankdir'] = "TB"
    graph.node_attr['fontsize'] = "10"

    for node in nodes:
        graph.node(node.get("id"), 
                   node.get("txt"), 
                   fillcolor=node.get("fillcolor"),
                   rank=node.get("rank"))

    edges = []
    for edge_dict in edge_dict_list:
        graph.edge(
            edge_dict.get("src"),
            edge_dict.get("dest"),
            label=edge_dict.get("label"),
            penwidth=edge_dict.get("penwidth")
        )

def gen_node(node_id, node_txt, fillcolor, rank):
    node = {}
    node["id"] = node_id
    node["txt"] = node_txt
    node["fillcolor"] = fillcolor
    node["rank"] = rank
    return node


In [4]:
def _total_num_of_caps(db):
    count_caps_q = "SELECT COUNT(*) FROM cap_info"
    return run_sql_query(db, count_caps_q)


In [5]:
def gen_full_graph(db, graph):
    get_bin_paths_q = "SELECT DISTINCT mmap_path FROM vm"
    path_list_json = run_sql_query(db, get_bin_paths_q)

    get_caps_q = "SELECT * FROM cap_info"
    caps = run_sql_query(db, get_caps_q)

    nodes = []
    edges = []

    for path_list in path_list_json:
        lib_start_q = "SELECT start_addr FROM vm WHERE mmap_path LIKE '%" + str(path_list[0]) + "%'"
        lib_end_q = "SELECT end_addr FROM vm WHERE mmap_path LIKE '%" + str(path_list[0]) + "%'"
        lib_start_addrs = run_sql_query(db, lib_start_q)
        lib_end_addrs = run_sql_query(db, lib_end_q)

        if (path_list[0] == "unknown" or
            path_list[0] == "Stack" or
            path_list[0] == "Guard"):

#             get_start_addr_q = "SELECT start_addr FROM vm WHERE mmap_path=\"" + path_list[0] + "\""
#             # Only interested in the first result?
#             start_addr_list_json = run_sql_query(db, get_start_addr_q)
            path_label = path_list[0] + " (" + lib_start_addrs[0][0] + ")"
            fillcolor = "lightgrey"
            rank = "source"
        else:
            path_label = path_list[0]
            fillcolor = "lightblue"
            rank = "max"

        nodes.append(gen_node(path_label, path_label, fillcolor, rank))

        for cap in caps:
            cap_loc_addr = cap[0]
            cap_path = cap[1]
            cap_addr = cap[2]
            cap_perms = cap[3]
            cap_base = cap[4]
            cap_top = cap[5]

            penwidth_weight = 1

            for lib_addr_index in range(len(lib_start_addrs)):
                lib_start_addr = lib_start_addrs[lib_addr_index][0]
                lib_end_addr = lib_end_addrs[lib_addr_index][0]

                # Also remove the appended (.bss/got/plt) and compare
                if cap_addr >= lib_start_addr and \
                    cap_addr <= lib_end_addr and \
                    cap_path != path_list[0] and \
                    cap_path != path_list[0][:-6] and \
                    cap_path[:-6] != path_list[0] and \
                    cap_path[:-6] != path_list[0][:-6]:
                    
                    if (cap_path == "unknown" or
                        cap_path == "Stack" or
                        cap_path == "Guard"):

                        get_start_addr_q = "SELECT start_addr FROM vm WHERE mmap_path=\"" + cap_path + "\""
                        # Only interested in the first result?
                        start_addr_list_json = run_sql_query(db, get_start_addr_q)
                        cap_path_label = cap_path + " (" + start_addr_list_json[0][0] + ")"
                    else:
                        cap_path_label = cap_path

                    for props in edges:
                        if props.get("src") == cap_path_label and props.get("dest") == path_label and props.get("label") == cap_perms:
                            penwidth_weight = math.log((10**(float(props.get("penwidth"))) + 1), 10)
                            edges.remove(props)
                            break
                    edges.append({"src":cap_path_label, "dest":path_label, "label":cap_perms, "penwidth":str(penwidth_weight)})
    gen_records(graph, nodes, edges)
    

In [6]:
def show_caps_between_two_libs(db, lib1, lib2, graph):
    lib1_caps_q = "SELECT * FROM cap_info WHERE cap_loc_path LIKE '%" + str(lib1) + "%'"
    lib1_caps = run_sql_query(db, lib1_caps_q)

    lib2_caps_q = "SELECT * FROM cap_info WHERE cap_loc_path LIKE '%" + str(lib2) + "%'"
    lib2_caps = run_sql_query(db, lib2_caps_q)
                    
    lib1_start_q = "SELECT start_addr FROM vm WHERE mmap_path LIKE '%" + str(lib1) + "%'"
    lib1_end_q = "SELECT end_addr FROM vm WHERE mmap_path LIKE '%" + str(lib1) + "%'"
    lib1_start_addrs = run_sql_query(db, lib1_start_q)
    lib1_end_addrs = run_sql_query(db, lib1_end_q)

    lib2_start_q = "SELECT start_addr FROM vm WHERE mmap_path LIKE '%" + str(lib2) + "%'"
    lib2_end_q = "SELECT end_addr FROM vm WHERE mmap_path LIKE '%" + str(lib2) + "%'"
    lib2_start_addrs = run_sql_query(db, lib2_start_q)
    lib2_end_addrs = run_sql_query(db, lib2_end_q)

    # node showing the loaded libraries, with cap records pointing into it
    # We have all the information here at this point, just need to graph it

    # First we create the two library nodes
    nodes = []
    edges = []

    nodes.append(gen_node(lib1, lib1, "lightblue", "same"))
    nodes.append(gen_node(lib2, lib2, "pink", "same"))

    for cap1 in lib1_caps:
        cap_loc_addr = cap1[0]
        cap_path = cap1[1]
        cap_addr = cap1[2]
        cap_perms = cap1[3]
        cap_base = cap1[4]
        cap_top = cap1[5]

        penwidth_weight = 1;

        
        for lib2_addr_index in range(len(lib2_start_addrs)):
            lib2_start_addr = lib2_start_addrs[lib2_addr_index][0]
            lib2_end_addr = lib2_end_addrs[lib2_addr_index][0]

            # Test if cap1's pointer address is in lib2's address range
            if cap_addr >= lib2_start_addr and cap_addr <= lib2_end_addr:
                print("cap1 to lib2: " + cap_perms)
                for props in edges:
                    if props.get("src") == lib1 and props.get("dest") == lib2:
                        penwidth_weight = math.log((10**(float(props.get("penwidth"))) + 1), 10)
                        edges.remove(props)
                        break
                edges.append({"src":lib1, "dest":lib2, "label":cap_perms, "penwidth":str(penwidth_weight)})

    for cap2 in lib2_caps:
        cap_loc_addr = cap2[0]
        cap_path = cap2[1]
        cap_addr = cap2[2]
        cap_perms = cap2[3]
        cap_base = cap2[4]
        cap_top = cap2[5]

        for lib1_addr_index in range(len(lib1_start_addrs)):
            lib1_start_addr = lib1_start_addrs[lib1_addr_index][0]
            lib1_end_addr = lib1_end_addrs[lib1_addr_index][0]

            # Test if cap2's pointer address is in lib1's address range
            if cap_addr >= lib1_start_addr and cap_addr <= lib1_end_addr:
                print("cap2 to lib1: " + cap_perms)

                for props in edges:
                    if props.get("src") == lib2 and props.get("dest") == lib1 and props.get("label") == cap_perms:
                        penwidth_weight = math.log((10**(float(props.get("penwidth"))) + 1), 10)
                        edges.remove(props)
                        break
                edges.append({"src":lib2, "dest":lib1, "label":cap_perms, "penwidth":str(penwidth_weight)})

    gen_records(graph, nodes, edges)

In [7]:
def show_comparts(db, graph):
    get_compart_id_q = "SELECT DISTINCT compart_id FROM vm"
    compart_ids = run_sql_query(db, get_compart_id_q)
    
    get_caps_q = "SELECT * FROM cap_info"
    caps = run_sql_query(db, get_caps_q)

    nodes = []
    edges = []

    for compart_id in compart_ids:
        single_compart_id = compart_id[0]
        get_path_for_id_q = "SELECT DISTINCT mmap_path FROM vm WHERE compart_id=" + str(single_compart_id)
        paths = run_sql_query(db, get_path_for_id_q)

        path = ""
        count = 0
        for each_path in paths:
            path += each_path[0]
            if (count != len(paths)-1):
                path += ", "
            count += 1
            
        if (single_compart_id < 0):
            fillcolor = "lightgrey"
            rank = "source"
        else:
            fillcolor = "lightblue"
            rank = "max"
        nodes.append(gen_node(
                        str(single_compart_id), 
                        str(single_compart_id)+": "+path, 
                        fillcolor,
                        rank))
        
        for path_list in paths:
            path_label = path_list[0]

            lib_start_q = "SELECT start_addr FROM vm WHERE mmap_path LIKE '%" + str(path_list[0]) + "%'"
            lib_end_q = "SELECT end_addr FROM vm WHERE mmap_path LIKE '%" + str(path_list[0]) + "%'"
            lib_start_addrs = run_sql_query(db, lib_start_q)
            lib_end_addrs = run_sql_query(db, lib_end_q)             

            for cap in caps:
                cap_loc_addr = cap[0]
                cap_path = cap[1]
                cap_addr = cap[2]
                cap_perms = cap[3]
                cap_base = cap[4]
                cap_top = cap[5]

                penwidth_weight = 1

                for lib_addr_index in range(len(lib_start_addrs)):
                    lib_start_addr = lib_start_addrs[lib_addr_index][0]
                    lib_end_addr = lib_end_addrs[lib_addr_index][0]

                    # Also remove the appended (.bss/got/plt) and compare
                    if cap_addr >= lib_start_addr and \
                        cap_addr <= lib_end_addr and \
                        cap_path != path_list[0] and \
                        cap_path != path_list[0][:-6] and \
                        cap_path[:-6] != path_list[0] and \
                        cap_path[:-6] != path_list[0][:-6]:
                    
#                         if (cap_path == "unknown" or
#                             cap_path == "Stack" or
#                             cap_path == "Guard"):

#                             get_start_addr_q = "SELECT start_addr FROM vm WHERE mmap_path=\"" + cap_path + "\""
#                             # Only interested in the first result?
#                             start_addr_list_json = run_sql_query(db, get_start_addr_q)
#                             cap_path_label = cap_path + " (" + start_addr_list_json[0][0] + ")"
#                         else:
#                             cap_path_label = cap_path

                        find_compart_id_q = "SELECT DISTINCT compart_id FROM vm WHERE mmap_path LIKE '%" + cap_path + "%'"
                        cap_path_compart_id_json = run_sql_query(db, find_compart_id_q)
                        # only need the first compart_id as they should be all the same 
                        cap_compart_id = cap_path_compart_id_json[0][0]
#                         print("cap_path_label: " + str(cap_compart_id) + " single_compart_id: " + str(single_compart_id) + " compart path: " + path)

                        for props in edges:
                            if props.get("src") == str(cap_compart_id) and props.get("dest") == str(single_compart_id) and props.get("label") == cap_perms:
                                penwidth_weight = math.log((10**(float(props.get("penwidth"))) + 1), 10)
                                edges.remove(props)
                                break
                        if (str(cap_compart_id) != str(single_compart_id)):
                            edges.append({"src":str(cap_compart_id), "dest":str(single_compart_id), "label":cap_perms, "penwidth":str(penwidth_weight)})
        
    gen_records(graph, nodes, edges)


In [9]:
CONFIG_FILE = 'chericat_cli.config'

def main(argv):
    if (os.path.isfile(CONFIG_FILE)):
        with open(CONFIG_FILE) as f:
            sys.argv = f.read().split(',')
    else:
         #sys.argv = ['chericat_cli', '-d', '/home/psjm3/chericat/trial.db', '-c', 'ld-elf', 'unknown']
          sys.argv = ['chericat_cli', '-d', '/home/psjm3/chericat/pp.db', '-l']

    parser = argparse.ArgumentParser(prog='chericat_cli')
    parser.add_argument(
        '-d', 
        help='The database to use for the queries', 
        required=True,
    )

    parser.add_argument(
        '-g',
        help='Generate full capability relationship in mmap graph',
        action='store_true',
    )
    
    parser.add_argument(
        '-r',
        help='Executes the SQL query on the provided db',
        nargs=1,
    )

    parser.add_argument(
        '-c',
        help="Show capabilities between two loaded libraries <libname 1> <libname 2>",
        nargs=2,
    )
    
    parser.add_argument(
        '-l',
        help="Show compartments graph",
        action='store_true',
    )

    args = parser.parse_args()

    if args.d:
        db = args.d

    if args.g:
        start = time.perf_counter()
        digraph = graphviz.Digraph('G', filename='graph_overview.gv')
        gen_full_graph(db, digraph)
        end = time.perf_counter()
        print("Full graph generation time taken: " + str(end-start) + "s")
        
        digraph.render(directory='graph-output', view=True)

    if args.r:
        print(run_sql_query(db, args.r[0]))

    if args.c:
        digraph = graphviz.Digraph('G', filename=args.c[0]+'_vs_'+args.c[1]+'.gv')
        show_caps_between_two_libs(db, args.c[0], args.c[1], digraph)
        digraph.render(directory='graph-output', view=True)
        
    if args.l:
        start = time.perf_counter()
        digraph = graphviz.Digraph('G', filename='comparts_graph.gv')
        show_comparts(db, digraph)
        end = time.perf_counter()
        print("Comparts graph generation time taken: " + str(end-start) + "s")
        
        digraph.render(directory='graph-output', view=True)
    
    return None

if __name__ == '__main__':
    main(sys.argv)


Comparts graph generation time taken: 0.016480879858136177s


FileNotFoundError: [Errno 2] No such file or directory: 'xdg-open'