In [1]:
import graphviz
import json
import sqlite3
import sys, os, argparse
import math
import time
import collections

In [2]:
def run_sql_query(db, query):
    """Run ``query`` against the SQLite database at path ``db`` and return all rows.

    Args:
        db: path of the SQLite database file (anything ``sqlite3.connect`` accepts).
        query: SQL statement to execute verbatim.

    Returns:
        A list with one list per result row (rows are lists, not tuples), e.g.
        ``[[1, 'a']]``; an empty list when the query yields no rows.

    The connection is always closed, even when the query raises. No commit is
    issued, so any data-modifying statement is rolled back on close.
    """
    conn = sqlite3.connect(db)
    try:
        rows = conn.execute(query).fetchall()
    finally:
        # sqlite3's context manager only manages transactions, not closing,
        # so close explicitly to avoid leaking one connection per query.
        conn.close()
    return [list(row) for row in rows]

In [3]:
def gen_records(graph, nodes, edge_dict_list):
    """Apply the shared visual style to ``graph`` and add the given nodes and edges.

    Args:
        graph: a graphviz ``Digraph`` (anything exposing ``attr``, ``node``,
            ``edge`` and the ``graph_attr``/``node_attr``/``edge_attr`` dicts).
        nodes: iterable of dicts as built by ``gen_node`` (keys ``id``, ``txt``,
            ``fillcolor``, ``rank``).
        edge_dict_list: iterable of dicts with keys ``src``, ``dest``,
            ``label`` and ``penwidth``.

    Returns None; ``graph`` is modified in place.
    """
    # NOTE(review): graphviz appears to emit this `node [...]` statement in the
    # graph body, after the node_attr defaults set below, so Courier may win
    # over Impact for node labels — confirm which font is intended.
    graph.attr('node', shape='record', fontname='Courier')
    graph.graph_attr['rankdir'] = "LR"

    graph.node_attr['fixedsize'] = "true"
    graph.node_attr['fontname'] = "Impact"
    graph.node_attr['color'] = "lightgreen"
    graph.node_attr['fillcolor'] = "lightblue"
    graph.node_attr['style'] = 'filled'
    graph.node_attr['fontsize'] = "20"
    graph.node_attr['height'] = "1"
    graph.node_attr['width'] = "3"

    graph.edge_attr['fontname'] = "Impact"
    graph.edge_attr['fontsize'] = "10"

    for node in nodes:
        # NOTE(review): Graphviz documents `rank` as a subgraph attribute, so
        # dot probably ignores it on individual nodes — confirm.
        graph.node(node.get("id"),
                   node.get("txt"),
                   fillcolor=node.get("fillcolor"),
                   rank=node.get("rank"))

    for edge_dict in edge_dict_list:
        graph.edge(
            edge_dict.get("src"),
            edge_dict.get("dest"),
            label=edge_dict.get("label"),
            penwidth=edge_dict.get("penwidth")
        )

def gen_node(node_id, node_txt, fillcolor, rank):
    """Bundle the attributes of one graph node into the dict ``gen_records`` expects.

    Returns a dict with keys ``id``, ``txt``, ``fillcolor`` and ``rank`` (in that
    order), holding the corresponding arguments unchanged.
    """
    return {
        "id": node_id,
        "txt": node_txt,
        "fillcolor": fillcolor,
        "rank": rank,
    }


In [4]:
def _total_num_of_caps(db):
    """Count the rows of the ``cap_info`` table in database ``db``.

    Returns the raw ``run_sql_query`` result, i.e. a single-row list of lists
    such as ``[[n]]`` rather than a bare integer.
    """
    return run_sql_query(db, "SELECT COUNT(*) FROM cap_info")


In [5]:
import os

def _load_libs(db):
    """Return the memory mappings stored in the ``vm`` table of ``db``.

    Each row becomes ``{"start_addr", "end_addr", "mmap_path"}``: columns 0 and 1
    parsed as hex integers, and the basename of column 2, with any basename
    starting with "unknown" renamed to "heap".
    """
    libs = []
    for row in run_sql_query(db, "SELECT * FROM vm"):
        start_addr = int(row[0], base=16)
        end_addr = int(row[1], base=16)
        mmap_path = os.path.basename(row[2])
        if mmap_path.startswith("unknown"):
            mmap_path = "heap"
        libs.append({"start_addr": start_addr, "end_addr": end_addr, "mmap_path": mmap_path})
    return libs


def _load_caps(db):
    """Return the capabilities stored in the ``cap_info`` table of ``db``.

    Each row becomes ``{"cap_loc_addr", "cap_loc_path", "cap_addr"}``: column 0
    as a hex integer (matched against the source mapping), the basename of
    column 1, and column 2 as a hex integer (matched against the destination
    mapping).
    """
    caps = []
    for row in run_sql_query(db, "SELECT * FROM cap_info"):
        caps.append({
            "cap_loc_addr": int(row[0], base=16),
            "cap_loc_path": os.path.basename(row[1]),
            "cap_addr": int(row[2], base=16),
        })
    return caps


def _contains(lib, addr):
    """True if ``addr`` lies in ``[lib["start_addr"], lib["end_addr"]]`` (both ends inclusive)."""
    # NOTE(review): if vm end addresses are exclusive, an address equal to
    # end_addr also matches the following mapping — confirm the convention.
    return lib["start_addr"] <= addr <= lib["end_addr"]


def gen_full_graph(db, graph):
    """Draw an edge between memory mappings for every capability linking them.

    For each mapping in ``vm`` (the source) and each capability in ``cap_info``
    whose ``cap_loc_addr`` lies inside it, every mapping containing the
    capability's ``cap_addr`` (the destination) gets a ``src -> dest`` edge.
    Edge widths grow logarithmically: the first capability gives width 1 and
    each further one maps ``w`` to ``log10(10**w + 1)``, i.e. roughly
    ``log10(9 + k)`` for ``k`` capabilities on the edge.

    Prints a ``{"src:dest": count}`` summary, then adds nodes and edges to
    ``graph`` via ``gen_records``. Returns None.
    """
    libs = _load_libs(db)
    caps = _load_caps(db)

    nodes = []
    # (src, dest) -> penwidth string. Updating an edge pops and re-inserts it,
    # moving it to the end exactly like a list remove-then-append, but with an
    # O(1) lookup instead of a scan over every edge built so far.
    edge_widths = {}
    # src -> {dest -> capability count}; dicts keep first-seen order for the printout.
    pair_counts = {}

    for src_lib in libs:
        src = src_lib["mmap_path"]
        # NOTE(review): _load_libs renames "unknown*" mappings to "heap", so the
        # "unknown" entry below never matches — confirm whether "heap" was meant.
        if src in ("unknown", "Stack", "Guard"):
            fillcolor, rank = "lightgrey", "source"
        else:
            fillcolor, rank = "lightblue", "max"

        nodes.append(gen_node(src, src, fillcolor, rank))

        for cap in caps:
            if not _contains(src_lib, cap["cap_loc_addr"]):
                continue

            for dest_lib in libs:
                if not _contains(dest_lib, cap["cap_addr"]):
                    continue
                dest = dest_lib["mmap_path"]

                dest_counts = pair_counts.setdefault(src, {})
                dest_counts[dest] = dest_counts.get(dest, 0) + 1

                # Width is derived per edge, so a capability that falls inside
                # several destination mappings cannot carry one edge's weight
                # over to another.
                previous = edge_widths.pop((src, dest), None)
                if previous is None:
                    penwidth_weight = 1
                else:
                    penwidth_weight = math.log((10 ** float(previous)) + 1, 10)
                edge_widths[(src, dest)] = str(penwidth_weight)

    edges = [
        {"src": src, "dest": dest, "label": "", "penwidth": width}
        for (src, dest), width in edge_widths.items()
    ]

    # Flatten to "src:dest" keys; summing guards against two pairs that happen
    # to produce the same joined string.
    count_map = {}
    for src, dest_counts in pair_counts.items():
        for dest, count in dest_counts.items():
            key = str(src) + ":" + str(dest)
            count_map[key] = count_map.get(key, 0) + count

    print(count_map)
    gen_records(graph, nodes, edges)
    

In [6]:
CONFIG_FILE = 'chericat_cli.config'


def _render_graph(builder, db, filename, label):
    """Build a graph with ``builder(db, digraph)``, report the build time, then render it.

    Args:
        builder: callable taking ``(db, digraph)`` that populates the digraph.
        db: path of the SQLite database passed through to ``builder``.
        filename: Graphviz source filename for the new ``Digraph``.
        label: prefix of the timing message, e.g. ``"Full graph"``.
    """
    start = time.perf_counter()
    digraph = graphviz.Digraph('G', filename=filename)
    builder(db, digraph)
    end = time.perf_counter()
    print(label + " generation time taken: " + str(end - start) + "s")
    # Writes under ./graph-output and asks the OS to open the rendered file.
    digraph.render(directory='graph-output', view=True)


def main(argv):
    """Parse chericat CLI options and run the requested queries / graph builds.

    ``argv`` is ignored. Inside a notebook the process arguments belong to the
    kernel, so options come from CONFIG_FILE when it exists. That file is
    comma-separated, and its first item (the program name) is skipped. Without
    the file, the full graph is built for the default database.
    Option values cannot contain commas in the config file.
    """
    dbpath = "../../chericat_dbs/"
    dbname = "server_no_c18n.db"
    if os.path.isfile(CONFIG_FILE):
        with open(CONFIG_FILE) as f:
            items = [item.strip() for item in f.read().split(',')]
        # Stripping keeps a trailing newline from turning "-g" into "-g\n";
        # empty items (e.g. from a trailing comma) are dropped.
        cli_args = [arg for arg in items[1:] if arg]
    else:
        cli_args = ['-d', dbpath + dbname, '-g']

    parser = argparse.ArgumentParser(prog='chericat_cli')
    parser.add_argument(
        '-d',
        help='The database to use for the queries',
        required=True,
    )

    parser.add_argument(
        '-g',
        help='Generate full capability relationship in mmap graph',
        action='store_true',
    )

    parser.add_argument(
        '-r',
        help='Executes the SQL query on the provided db',
        nargs=1,
    )

    parser.add_argument(
        '-comp',
        help="Show compartments graph",
        action='store_true',
    )

    parser.add_argument(
        '-compc',
        help="Show compartments (has_capability relationship) graph",
        action='store_true',
    )

    # Parse an explicit list instead of overwriting the global sys.argv.
    args = parser.parse_args(cli_args)

    if not args.d:
        parser.error("-d must name a database")
    db = args.d
    # Name outputs after the database actually used, not the built-in default.
    out_base = os.path.basename(db)

    if args.g:
        _render_graph(gen_full_graph, db, out_base + '.libview_graph.gv', "Full graph")

    if args.r:
        print(run_sql_query(db, args.r[0]))

    # NOTE(review): show_comparts and show_comparts_has_caps are not defined in
    # the visible cells; confirm they exist elsewhere, otherwise -comp / -compc
    # raise NameError.
    if args.comp:
        _render_graph(show_comparts, db, out_base + '.comp_graph.gv', "Comparts graph")

    if args.compc:
        _render_graph(show_comparts_has_caps, db, out_base + '.comp_no_perms_graph.gv',
                      "Capability comparts graph")

    return None


if __name__ == '__main__':
    main(sys.argv)


{'server(.got):server(.plt)': 8, 'server(.got):server': 133, 'server(.got):libc.so.7(.plt)': 1, 'server(.got):libc.so.7': 2, 'server:Heap(others)': 7, 'server:server': 1, 'server:ld-elf.so.1': 1, 'server:libc.so.7(.plt)': 6, 'server:server(.plt)': 25, 'server:libsys.so.7(.plt)': 7, 'ld-elf.so.1(.got):ld-elf.so.1': 723, 'ld-elf.so.1(.got):ld-elf.so.1(.got)': 4, 'ld-elf.so.1:ld-elf.so.1': 79, 'ld-elf.so.1:Heap(others)': 32, 'ld-elf.so.1:ld-elf.so.1(.got)': 2, 'ld-elf.so.1:libthr.so.3': 4, 'ld-elf.so.1:libthr.so.3(.plt)': 9, 'Heap(others):Heap(others)': 967, 'Heap(others):libsys.so.7': 29, 'Heap(others):libc.so.7': 81, 'Heap(others):libthr.so.3': 48, 'Heap(others):ld-elf.so.1': 16, 'Heap(others):server': 26, 'Heap(others):libthr.so.3(.got)': 4, 'Heap(others):server(.got)': 3, 'Heap(others):server(.plt)': 2, 'Heap(others):libsys.so.7(.got)': 3, 'Heap(others):libc.so.7(.got)': 6, 'Heap(others):Guard': 1, 'Heap(others):Stack': 1, 'libthr.so.3(.got):libthr.so.3(.plt)': 152, 'libthr.so.3(.got)