# Chapter 2: Function Similarity

In this chapter, we learn how to judge the similarity of two functions and build a database that stores the attributes of known functions.

* **Task 1:** Define the keys used for matching
* **Task 2:** Define the SQL statements, such as CREATE TABLE, CREATE INDEX, and INSERT
* **Task 3:** Get the attributes of every function and construct the SQL rows
* **Task 4:** Create a database and an index, then insert the attributes with the helper functions
* **Task 5:** Test the database with a similar program
    * **Task 5.1:** Define functions that calculate the similarity percentage of two functions
    * **Task 5.2:** Find the `several` most similar functions to the input function
    * **Task 5.3:** Find the matching functions
* **Task 6:** Save the program and close the project

In [None]:
# Task 1: Define the key used for match 

# A function must span strictly more than this many addresses to be kept.
LOWER_BOUND_ADDR_NUM = 20

def filter_func(func: 'ghidra.program.model.listing.Function') -> bool:
    '''
    Decide whether a function is worth fingerprinting.

    Thunk functions (a single instruction that transfers to another function)
    are rejected, as are functions whose body spans LOWER_BOUND_ADDR_NUM
    addresses or fewer (the comparison is strict, so exactly 20 is rejected).
    '''
    if func.isThunk():
        return False
    body_size = func.getBody().getNumAddresses()
    return body_size > LOWER_BOUND_ADDR_NUM

def get_inst_key(func: 'ghidra.program.model.listing.Function') -> tuple:
    '''
    Build the instruction-level key of a function.

    Returns a 2-tuple:
        [0] the number of addresses in the function body (as a Python int),
        [1] the mnemonics of every code unit in the body, joined with ",".
    The ``True`` passed to ``getCodeUnits`` requests forward iteration.
    '''
    listing = func.getProgram().getListing()
    mnemonics = [unit.getMnemonicString()
                 for unit in listing.getCodeUnits(func.getBody(), True)]
    # TODO: consider convert tuple to dict to avoid use index to access the value
    return (int(func.body.numAddresses), ",".join(mnemonics))
    
def get_struct_graph_key(func: 'ghidra.program.model.listing.Function') -> tuple:
    '''
    Collect control-flow-graph counters for a function.

    Walks every basic block (from a BasicBlockModel) that overlaps the
    function body and returns (blocks, edges, calls, jumps):
        blocks - number of basic blocks,
        edges  - sum of each block's getNumDestinations(),
        calls  - destination references whose flow type is a call,
        jumps  - destination references whose flow type is a jump.
    Every counter is offset by +1 so a product of them is never zero.
    '''
    # A standalone monitor avoids having to pass a flat_api handle in.
    from ghidra.util.task import ConsoleTaskMonitor
    from ghidra.program.model.block import BasicBlockModel
    task_monitor = ConsoleTaskMonitor()
    model = BasicBlockModel(func.getProgram(), True)

    # Seeded with 1 so the multiplicative hash built from these never becomes 0.
    blocks = edges = calls = jumps = 1
    for code_block in model.getCodeBlocksContaining(func.getBody(), task_monitor):
        blocks += 1
        edges += code_block.getNumDestinations(task_monitor)
        dest_iter = code_block.getDestinations(task_monitor)
        while dest_iter.hasNext():
            flow = dest_iter.next().getFlowType()
            if flow.isCall():
                calls += 1
            elif flow.isJump():
                jumps += 1
    # TODO: consider convert tuple to dict to avoid use index to access the value
    return (blocks, edges, calls, jumps)

# TODO: add more keys 

In [None]:
# Task 2: define some sql sentences such as create table and index, insert sql 
import sqlite3 

# Column names of the function table. The order must match the row tuples built
# in Task 3 (name, hash, get_inst_key() pair, get_struct_graph_key() counters),
# and compare_func reads result rows by position (mnemonics at [4], graph at [5:9]
# once the leading `id` column is counted).
FUNC_KEYS  = ['name','hash','numAddresses','mnemonics','block_num','edge_num','call_num','jump_num']
# SQLite column type for each entry of FUNC_KEYS, position by position.
FUNC_TYPES = ['TEXT','INTEGER','INTEGER','TEXT','INTEGER','INTEGER','INTEGER','INTEGER']
# Name of the table that stores one row per known function.
FUNC_TABLE_NAME = 'func_table'

def sql_create_table(cursor: sqlite3.Cursor, keys: list, types: list, table_name: str):
    '''
    Create ``table_name`` if it does not exist yet.

    The table gets an ``id INTEGER PRIMARY KEY`` column followed by one column
    per (key, type) pair. The generated statement is printed before it runs.
    Table and column names are interpolated into the SQL text (identifiers
    cannot be bound as parameters), so only pass trusted names.

    Raises:
        ValueError: if ``keys`` and ``types`` differ in length; ``zip`` would
            otherwise silently drop the unmatched columns.
    '''
    if len(keys) != len(types):
        raise ValueError(f'keys and types differ in length: {len(keys)} != {len(types)}')
    # Building the column list first avoids a dangling comma when `keys` is empty.
    columns_ = ['id INTEGER PRIMARY KEY'] + [f'{k} {t}' for k, t in zip(keys, types)]
    create_sql_ = f'CREATE TABLE IF NOT EXISTS {table_name} (' + ','.join(columns_) + ');'
    print(create_sql_)
    cursor.execute(create_sql_)
    
def sql_create_index(cursor: sqlite3.Cursor, table_name: str, index: list, index_name: str):
    '''
    Create the index ``index_name`` over the ``index`` columns of ``table_name``
    unless it already exists. The statement is printed before execution;
    identifiers are interpolated directly, so only pass trusted names.
    '''
    column_list = ','.join(index)
    statement = f'CREATE INDEX IF NOT EXISTS {index_name} on {table_name}({column_list});'
    print(statement)
    cursor.execute(statement)
    
def sql_insert(cursor : sqlite3.Cursor, keys: list, val: list, table_name: str):
    '''
    Insert multiple rows in one ``executemany`` call.

    ``val`` is a sequence of rows whose items line up with ``keys``. Values
    are bound through ``?`` placeholders; the table and column names are
    interpolated into the SQL text.
    '''
    # TODO: add the deduplication of database
    column_list = ','.join(keys)
    placeholders = ','.join('?' for _ in keys)
    statement = f'INSERT INTO {table_name} ({column_list}) VALUES ({placeholders})'
    cursor.executemany(statement, val)

In [None]:
# Reuse the function in Chapter 1
! pip install pyhidra > /dev/null
# import launcher
from pyhidra.launcher import PyhidraLauncher, GHIDRA_INSTALL_DIR

class HeadlessLoggingPyhidraLauncher(PyhidraLauncher):
    """
    Headless pyhidra launcher whose Ghidra application log path can be chosen.

    Slightly adapted from pyhidra's launcher: the only addition is the optional
    ``log_path``, which is handed to the headless application configuration.
    """

    def __init__(self, verbose=False, log_path=None):
        super().__init__(verbose)
        # Target file for Ghidra's application log; a falsy value keeps the default.
        self.log_path = log_path

    def _launch(self):
        # Java/Ghidra names are imported lazily inside the launch hook,
        # presumably because they only resolve once the JVM is running.
        from pyhidra.launcher import _silence_java_output
        from ghidra.framework import Application, HeadlessGhidraApplicationConfiguration
        from java.io import File

        quiet = not self.verbose
        with _silence_java_output(quiet, quiet):
            app_config = HeadlessGhidraApplicationConfiguration()
            if self.log_path:
                app_config.setApplicationLogFile(File(self.log_path))
            Application.initializeApplication(self.layout, app_config)

In [None]:
# start the Launcher 
# Verbose JVM output; Ghidra's application log is written to ./launch.log.
launcher = HeadlessLoggingPyhidraLauncher(
    log_path='./launch.log',
    verbose=True,
)
launcher.start()

In [None]:
# Reuse the project create or open in chapter 1 
# Necessary imports for ghidra project 
from ghidra.base.project import GhidraProject
from java.io import IOException
from pathlib import Path 

# Create Project Dir and name 
project_name = "database_project"
project_location = Path('./ghidra_project')
project_location.mkdir(parents=True, exist_ok=True)

# Reopen the project when possible; if opening fails with IOException
# (e.g. the project does not exist yet), create a new one instead.
try:
    project = GhidraProject.openProject(project_location, project_name, True)
except IOException:
    project = GhidraProject.createProject(project_location, project_name, False)
    print(f'Created project: {project.project.name}')
else:
    print(f'Opened project: {project.project.name}')

In [None]:
from ghidra.program.flatapi import FlatProgramAPI

# Import the reference firmware whose functions populate the database.
program_path = Path('./stm32x9i_ssl_client.elf')
program = project.importProgram(program_path)

# Run Ghidra's auto-analysis through a FlatProgramAPI bound to this program.
flat_api = FlatProgramAPI(program)
flat_api.analyzeAll(program)

In [None]:
# Task 3: get the attributes of every function and construct the sql row
from functools import reduce
from operator import mul

# One row per kept function, laid out in FUNC_KEYS order:
# (name, hash, numAddresses, mnemonics, block_num, edge_num, call_num, jump_num)
rows_ = []
for func_ in program.getListing().getFunctions(True):
    if not filter_func(func_):
        continue
    inst_ = get_inst_key(func_)
    graph = get_struct_graph_key(func_)
    # inst_[0] is the address count; the hash is it times every graph counter.
    hash_ = inst_[0] * reduce(mul, graph, 1)
    if hash_ >= 0xffffffff:
        print(f'WARNING: {func_.getName()} hash is a little long {hash_}')
    row_ = (func_.getName(), hash_, *inst_, *graph)
    rows_.append(row_)

In [None]:
# Task 4: Create a database, index and insert the attributes with helper functions
import os 
DATABASE = './test_db.db'
# Rebuild the database from scratch each run. On the very first run the file
# does not exist yet, which os.remove reports as FileNotFoundError.
try:
    os.remove(DATABASE)
except FileNotFoundError:
    pass

conn = sqlite3.connect(DATABASE)
try:
    cursor = conn.cursor()

    # create table and index
    sql_create_table(cursor, FUNC_KEYS, FUNC_TYPES, FUNC_TABLE_NAME)
    sql_create_index(cursor, FUNC_TABLE_NAME, ['hash'], 'index_hash')
    conn.commit()

    # insert the rows
    sql_insert(cursor, FUNC_KEYS, rows_, FUNC_TABLE_NAME)
    conn.commit()
finally:
    # close the connection even if one of the statements fails
    conn.close()

In [None]:
# Task 5: test the database with similar program
from difflib import SequenceMatcher

# the configurations
# Hash search window: candidates must have a hash within ±10% of the query's hash.
TOLERANT_ = 0.1 # tolerant level: ±10% around the hash
# Minimum similarity (both graph-attribute similarity and mnemonic ratio) a
# candidate needs to be accepted; values below it are skipped, so >= 95% passes.
ACCEPT_ = 0.95 # accept level: >=95%

# Task 5.1: Define some functions to calculate the similar percentage of two functions  
# the percentage of different attributes 
def diff_percentage(base: tuple, param: tuple) -> float:
    '''
    Return the mean relative difference between two equal-length numeric tuples.

    base:  reference values (the divisor of each relative difference)
    param: values compared against ``base``

    Each position contributes ``abs(b - p) / abs(b)``. When ``b`` is 0, the
    position counts as 0.0 if ``p`` is also 0 and 1.0 (fully different)
    otherwise, instead of dividing by zero. Two empty tuples differ by 0.0.

    On a length mismatch an error is printed and ``float('inf')`` is returned,
    so ``1 - diff_percentage(...)`` can never pass a similarity threshold.
    Returning 0 here would claim the tuples are identical.
    '''
    if len(base) != len(param):
        print("ERROR: different diff length")
        return float('inf')
    if not base:
        return 0.0
    total = 0.0
    for b, p in zip(base, param):
        if b:
            total += abs(b - p) / abs(b)
        else:
            total += 0.0 if p == 0 else 1.0
    return total / len(base)
    
# Task 5.2: Find [several] most similar functions with the input func
def compare_func(cursor: sqlite3.Cursor, func: 'ghidra.program.model.listing.Function') -> tuple:
    '''
    Find the database functions most similar to ``func``.

    1. Compute ``func``'s keys and hash exactly as the database rows were built.
    2. Let SQL select rows whose hash lies within ±TOLERANT_ of that hash,
       ordered by hash distance (closest first).
    3. Keep a row only if its graph-attribute similarity
       (1 - diff_percentage) and the SequenceMatcher ratio of its mnemonic
       string are both at least ACCEPT_. Its score is the mean of the two.

    Returns:
        (max_, matches):
            max_    - ``[best_row, best_score]``; stays ``[(1,), 0]`` when no
                      row is accepted. On equal scores the closer hash wins.
            matches - ``[row, score]`` for every accepted row, in hash-distance order.
    '''
    inst_ = get_inst_key(func)
    graph = get_struct_graph_key(func)
    # inst_[0] is the address count; same hash formula as used when filling the table.
    hash_ = reduce(mul, graph, 1) * inst_[0]
    cursor.execute(
        f'SELECT * FROM {FUNC_TABLE_NAME} WHERE hash BETWEEN ? AND ? ORDER BY ABS(hash-?)',
        (round(hash_ * (1 - TOLERANT_)), round(hash_ * (1 + TOLERANT_)), hash_),
    )
    results = cursor.fetchall()

    matches = []
    max_ = [(1,), 0]
    for result_ in results:
        # Row layout: id, name, hash, numAddresses, mnemonics, 4 graph counters.
        same_ = 1 - diff_percentage(graph, result_[5:9])
        if same_ < ACCEPT_:
            continue
        # Commas only separate mnemonics, so treat them as junk characters.
        match = SequenceMatcher(lambda x: x == ',', result_[4], inst_[1], autojunk=False)
        # quick_ratio() is a cheap upper bound on ratio(); reject early with it.
        if match.quick_ratio() < ACCEPT_:
            continue
        ratio_ = match.ratio()
        if ratio_ < ACCEPT_:
            continue
        score_ = (same_ + ratio_) / 2
        # Compare like with like (average vs stored average); strict '>' keeps
        # the earliest, i.e. closest-hash, row among equal scores.
        if score_ > max_[1]:
            max_[1] = score_
            max_[0] = result_
        matches.append([result_, score_])
    return (max_, matches)

In [None]:
from ghidra.program.flatapi import FlatProgramAPI

# Import the second firmware whose functions are matched against the database.
test_program_path = Path('./stm32x9i_freertos_mpu.elf')
test_program = project.importProgram(test_program_path)

# Run Ghidra's auto-analysis through a FlatProgramAPI bound to the test program.
flat_ = FlatProgramAPI(test_program)
flat_.analyzeAll(test_program)

In [None]:
# Task 5.3: find the matches functions 
import time
conn = sqlite3.connect(DATABASE)
cursor = conn.cursor()
start = time.time()
try:
    for func_ in test_program.getListing().getFunctions(True):
        if not filter_func(func_):
            continue
        (max_, matches_) = compare_func(cursor, func_)
        if matches_:
            # max_ is [best row, best score]; column 1 of a row (after `id`) is the stored name.
            print(f'Match {func_.getName()} -> {max_[0][1]} '
                  f'(score {max_[1]:.2%}, {len(matches_)} candidate(s))')
        else:
            print(f'No match {func_.getName()}')
finally:
    # release the database even if matching a function fails
    conn.close()
print(f'Matching finished in {time.time() - start:.2f}s')

In [None]:
# Task 6: Save the program and close the project
# Store the imported program at the project root under its own name,
# overwriting an existing entry of that name (last argument True).
program_name = program.getName()
project.saveAs(program, "/", program_name, True)
# Persist any remaining changes, then release the project.
project.save(program)
project.close()

In [None]:
# TODO: add a human-defined database to classify the RTOS, library, and crypto functions
# TODO: add pseudo-code and decompiled C string comparison