In [None]:
# import all the required libraries
import copy
import gc
import json
import logging
import multiprocessing
import os
import random
import re
import shutil
import signal
import sqlite3
import subprocess
from multiprocessing import Pool, set_start_method
from multiprocessing.dummy import Pool as ThreadPool

import angr, csv, sys, statistics
import networkx as nx
import numpy as np
import pandas as pd
from joblib import dump, load
from sklearn.ensemble import RandomForestClassifier
from tqdm import tqdm


In [None]:

# Quieten angr: the top-level 'angr' logger is limited to ERROR and above and
# does not hand its records on to ancestor loggers.
logger = logging.getLogger('angr')
logger.propagate = False
logger.setLevel(logging.ERROR)

# The VEX propagator engine logger is restricted further, to CRITICAL only.
vex_logger = logging.getLogger('angr.analyses.propagator.engine_vex.SimEnginePropagatorVEX')
vex_logger.setLevel(logging.CRITICAL)

class IgnoreSpecificErrorFilter(logging.Filter):
    """Reject records whose message contains "SimEnginePropagatorVEX | Unsupported"."""

    def filter(self, record):
        """Return False for the unwanted message, True for everything else."""
        rendered = record.getMessage()
        is_noise = "SimEnginePropagatorVEX | Unsupported" in rendered
        return not is_noise


class IgnoreWarningsFilter(logging.Filter):
    """Let through only records below WARNING.

    Despite the name, this rejects WARNING *and every higher level*
    (ERROR, CRITICAL), because the test is ``levelno < logging.WARNING``.
    """

    def filter(self, record):
        """Return True when the record's level is strictly below WARNING."""
        severity = record.levelno
        return severity < logging.WARNING

# Register both record filters on the 'angr' logger, in the same order as before:
# the message-based filter first, then the level-based one.
for _record_filter in (IgnoreSpecificErrorFilter(), IgnoreWarningsFilter()):
    logger.addFilter(_record_filter)




In [None]:
# Size in bytes of each injectable payload; index 0..3 corresponds to
# payload_a..payload_d (add_payload indexes this with ord(pn) - ord('a')).
p_bytes = [10, 25, 7, 8] # payload size(bytes)

# Byte budget of one attack step: a step injects int(blockSize / p_bytes[i]) * scale
# copies of payload i.
blockSize = 125

# Mnemonic tables used by extract_feature to bucket disassembled instructions.
# Judging by their contents: data transfer, arithmetic/logic/compare, control flow.
trans = [
    'mov', 'movabs', 'push', 'pop', 'lb', 'lbu', 'lh', 'lw',
    'sb', 'sh', 'sw', 'ldc', 'lea', 'restore', 'lswi', 'sts',
    'usp', 'srs', 'pea', 'lui', 'lhu', 'rdcycle', 'rdtime', 'rdinstret',
]
cal = [
    'add', 'sub', 'inc', 'xor', 'sar', 'addi', 'addiu', 'addu',
    'and', 'ldr', 'andi', 'nor', 'or', 'ori', 'subu', 'xori',
    'div', 'divu', 'mfhi', 'mflo', 'mthi', 'mtlo', 'mult', 'multu',
    'sll', 'sllv', 'sra', 'srav', 'srl', 'srlv', 'bic', 'xnor',
    'not', 'eor', 'asr', 'fabs', 'abs', 'mac', 'neg', 'cmp',
    'test', 'slti', 'slt', 'sltu', 'sltui', 'sltiu', 'cmn', 'fcmp',
    'dcbi', 'tas', 'btst', 'cbw', 'cwde', 'cdqe', 'cdq', 'slli',
    'srli', 'srai', 'auipc', 'adc', 'sbb',
]
ctl = [
    'jmp', 'jz', 'jnz', 'jne', 'je', 'call', 'jr', 'beq',
    'bge', 'bgeu', 'bgez', 'bgezal', 'bgtz', 'blez', 'blt', 'bltu',
    'bltz', 'bltzal', 'bne', 'break', 'j', 'jal', 'jalr', 'mfc0',
    'mtc0', 'syscall', 'leave', 'hvc', 'svc', 'hlt', 'arpl', 'sys',
    'ti', 'trap', 'ret', 'retn', 'bl', 'bicc', 'bclr', 'bsrf',
    'rte', 'wait', 'fwait', 'wfe', 'ecall', 'ebreak', 'jb', 'jbe',
]


In [None]:

def _mean_normalised_degree(degree_view):
    """Mean of the per-node degrees after dividing each by the degree total.

    Returns 0.0 when the total is zero (no nodes or no edges) instead of
    dividing by zero.
    """
    degrees = [deg for _, deg in degree_view]
    total = sum(degrees)  # computed once rather than once per node
    if not total:
        return 0.0
    return np.mean([deg / total for deg in degrees])


def extract_feature(file_name):
    """Build the feature vector of a binary from its static control-flow graph.

    The binary is loaded with angr (without shared libraries) and analysed
    with ``CFGFast``.

    Args:
        file_name: path of the binary to analyse.

    Returns:
        A list of 18 values, in this order:
        nodes, edges, mean normalised out-degree, mean normalised in-degree,
        density, mean closeness centrality, mean betweenness centrality,
        number of connected components (undirected view), "diameter",
        "radius", total trans/cal/ctl opcode counts, the same three counts
        divided by the node count, edges per node, and the mean number of
        instructions per disassembled block.

    Notes:
        * "diameter" and "radius" are the max / min of ``len(path)`` over all
          shortest paths between distinct nodes, i.e. the length of the path
          object networkx returns, and "radius" is a global minimum rather than
          the graph-theoretic radius. Both definitions are kept exactly as
          they were; presumably the saved scaler/model were fitted on these
          values -- confirm before changing them.
        * For every CFG node the block at ``node.function_address`` (not the
          node's own address) is disassembled; this is also unchanged.
          NOTE(review): confirm this is intended.
        * Degenerate graphs (no nodes, no edges, no reachable pairs, no block
          that could be disassembled) produce 0 for the affected features
          instead of raising.
    """
    p = angr.Project(file_name, load_options={'auto_load_libs': False})
    # Static CFG; CFGEmulated(keep_state=True) would involve execution.
    cfg = p.analyses.CFGFast()
    G = cfg.graph

    # ---------------------------------------------------------
    # nodes & edges
    nodes = G.number_of_nodes()
    edges = G.number_of_edges()

    # ---------------------------------------------------------
    # degree (normalised by the total, then averaged)
    in_degree = _mean_normalised_degree(G.in_degree())
    out_degree = _mean_normalised_degree(G.out_degree())

    # density
    density = nx.density(G)

    # closeness / betweenness centrality, averaged over nodes
    if nodes:
        closeness_centrality = np.mean(list(nx.closeness_centrality(G).values()))
        betweeness_centrality = np.mean(list(nx.betweenness_centrality(G).values()))
    else:
        closeness_centrality = 0.0
        betweeness_centrality = 0.0

    # connected components of the undirected view
    connected_components = nx.number_connected_components(G.to_undirected())

    # ---------------------------------------------------------
    # shortest paths: only the lengths are needed, so the paths of one source
    # are dropped before the next source is processed.
    sp = []
    for src, paths in nx.all_pairs_shortest_path(G):
        for dst, path in paths.items():
            if src != dst:
                sp.append(len(path))

    diameter = max(sp, default=0)
    radius = min(sp, default=0)

    # ---------------------------------------------------------
    # collecting opcode features from instructions
    # Raw string: '\w' is an invalid escape in a normal string literal.
    # Presumably the mnemonic sits between two tabs in str(insn) -- the
    # pattern itself is unchanged.
    regex = re.compile(r'\t\w+\t')
    instruction = []
    func_size = []

    for node in G.nodes():
        try:
            block_split = p.factory.block(node.function_address).capstone.insns
            func_size.append(len(block_split))
            for insn in block_split:
                instruction.append(regex.findall(str(insn))[0][1:-1])
        except Exception as exc:
            # Best effort, as before: a block that cannot be lifted or parsed
            # is skipped, but the reason is now available at DEBUG level.
            logging.getLogger(__name__).debug(
                "skipping block of node %r in %s: %s", node, file_name, exc)

    # Set lookups; the trans -> cal -> ctl precedence of the original is kept.
    trans_set, cal_set, ctl_set = set(trans), set(cal), set(ctl)
    total_trans = total_cal = total_ctl = 0
    for mnemonic in instruction:
        if mnemonic in trans_set:
            total_trans += 1
        elif mnemonic in cal_set:
            total_cal += 1
        elif mnemonic in ctl_set:
            total_ctl += 1

    # Avg. instruction count per CFG node
    block_num = nodes
    if block_num:
        avg_trans = total_trans / block_num
        avg_cal = total_cal / block_num
        avg_ctl = total_ctl / block_num
        avg_block = edges / block_num
    else:
        avg_trans = avg_cal = avg_ctl = avg_block = 0.0

    avg_block_size = statistics.mean(func_size) if func_size else 0

    # ---------------------------------------------------------
    return [nodes, edges, out_degree, in_degree, density, closeness_centrality, betweeness_centrality, connected_components, diameter, radius, total_trans, total_cal, total_ctl, avg_trans, avg_cal, avg_ctl, avg_block, avg_block_size]



In [None]:
# payload injection into the binary

def check_avail_bytes(file):
    """Ask elfinjector how many bytes are available for injection in ``file``.

    A throw-away copy ``./test/<basename>.test`` is handed to
    ``./elfinjector/build/elfinjector`` together with ``payload_a`` so that
    the original file is not touched by this probe. The count is parsed from
    the first ``(... available)`` group of the tool's standard output.

    The copy and the tool are invoked without a shell, so file names with
    spaces or shell metacharacters are handled literally, and the probe copy
    is removed even when something in between fails.

    Args:
        file: path of the binary to probe.

    Returns:
        The available byte count, or -1 when the probe could not be run or
        its output contains no ``(... available)`` group.
    """
    probe_path = f'./test/{os.path.basename(file)}.test'
    try:
        shutil.copy(file, probe_path)
        completed = subprocess.run(
            ['./elfinjector/build/elfinjector', probe_path, './elfinjector/build/payload_a'],
            stdout=subprocess.PIPE,
            universal_newlines=True,
        )
        result = completed.stdout
    except (OSError, subprocess.SubprocessError):
        # Missing sample, missing ./test directory or missing injector:
        # same outcome as an unparsable answer.
        return -1
    finally:
        try:
            os.remove(probe_path)
        except OSError:
            pass

    msg = re.findall(r'\(.*?available\)', result)
    if not msg:
        return -1

    # First whitespace-separated token of the match, minus the leading '('.
    available_bytes = int(msg[0].split()[0][1:])

    gc.collect()
    return available_bytes
    
def extend(fn):
    """Run the external ``./tpi`` tool on ``fn`` and re-probe the available bytes.

    Presumably ``./tpi`` enlarges the injectable area of the binary -- its
    behaviour is not visible here. It is invoked without a shell so the file
    name is passed literally.

    Args:
        fn: path of the binary to process (modified in place by the tool).

    Returns:
        0 when the follow-up probe reports an available byte count (which is
        also printed), -1 when the probe fails.
    """
    try:
        subprocess.run(['./tpi', fn])
    except OSError:
        # The exit status was never checked; a tool that cannot be started is
        # treated the same way and the probe below decides the result.
        pass

    tmp = check_avail_bytes(fn)
    if tmp == -1:
        return -1

    print("file name:", (fn.split(os.sep)[-1]), "\nAvailable bytes:", tmp)
    return 0
    
def add_payload(fn, pn, N):
    """Inject payload ``pn`` into the binary ``fn`` ``N`` times.

    If the available space does not exceed ``p_bytes[pn] * N`` the binary is
    first passed through ``extend``; the injection is attempted regardless of
    whether that succeeded, as before. The injector is invoked without a
    shell so the file name is passed literally.

    Args:
        fn: path of the binary, modified in place.
        pn: payload letter, 'a'..'d' (selects ``payload_<pn>`` and the
            matching entry of ``p_bytes``).
        N: number of copies to inject.

    Returns:
        -1 when the available bytes cannot be determined or the injector
        cannot be started, otherwise None.
    """
    ab = check_avail_bytes(fn)
    if ab == -1:
        return -1

    needed = p_bytes[ord(pn) - ord('a')] * N
    if ab <= needed:
        extend(fn)

    payload = f'./elfinjector/build/payload_{pn}'
    for _ in range(N):
        try:
            subprocess.run(['./elfinjector/build/elfinjector', fn, payload])
        except OSError:
            return -1
        

def gen_AE(fn, pn, N): # file_name, payload_name, payload_num
    """Inject ``N`` copies of the payload with numeric index ``pn`` into ``fn``.

    Index 0, 1, 2, 3 selects payload 'a', 'b', 'c', 'd' respectively; any
    other value does nothing. Always returns None.
    """
    for index, letter in enumerate('abcd'):
        if pn == index:
            add_payload(fn, letter, N)
            break



In [None]:
# Adversarial examples generation

class Attack:
    """Greedy search for a payload sequence that makes a binary look benign.

    Attributes:
        file_name: path of the binary under attack (never modified; every
            candidate is built on a ``<file_name>.bak`` copy).
        model: classifier exposing ``predict_proba``; column 0 is read as the
            benign probability.
        scaler: feature scaler loaded with joblib from ``scaler_path``.
    """

    def __init__(self, file, model, scaler_path):
        self.file_name = file
        self.model = model
        self.scaler = load(scaler_path)

    def _benign_proba(self, path):
        """Return the model's benign probability for the binary at ``path``."""
        feature = extract_feature(path)
        scaled_feature = self.scaler.transform(pd.DataFrame([feature]))  # Scale the feature
        return self.model.predict_proba(scaled_feature)[0][0]

    def _score_candidate(self, seq):
        """Apply ``seq`` to a fresh ``.bak`` copy and return its benign probability.

        The copy is extended, every ``(payload_index, count)`` step of ``seq``
        is injected in order, the result is scored, and the copy is removed
        even if one of those steps raises.
        """
        bak = f'{self.file_name}.bak'
        shutil.copy(self.file_name, bak)
        try:
            extend(bak)
            for payload, count in seq:
                gen_AE(bak, payload, count)
            return self._benign_proba(bak)
        finally:
            try:
                os.remove(bak)
            except OSError:
                pass

    def attack(self):
        """Run the search and summarise its outcome.

        While the benign probability is below 0.5, each round tries to append
        one more step for every payload index except the one
        ``check_payload_num`` reports as over-used. A candidate is accepted
        when its probability exceeds ``get_thd(best_so_far, Thd)``; because
        that threshold is randomly perturbed, a candidate slightly worse than
        the current best can be accepted. If no candidate is accepted the
        step size ``scale`` grows (by 1, 2, 3, ...) and the round is retried;
        the search gives up once ``scale`` has reached ``max_scale``.

        Returns:
            dict with keys
            ``label`` (1 if the final benign probability is still < 0.5, else 0),
            ``used_bytes`` (sum of count * payload size over the sequence),
            ``attack_num`` (copies injected per payload index),
            ``iter`` (number of rounds) and
            ``attack_seq`` (payload indices of the accepted steps, concatenated).
        """
        ori_proba = self._benign_proba(self.file_name)  # benign probability
        print("Original Probability: ", ori_proba)
        attack_seq = []

        proba = ori_proba

        max_scale = 150  # maximum stagnation allowed
        scale = 1
        T = 0
        flag = 0
        Thd = 1000
        while proba < 0.5:
            max_proba = proba
            max_seq = list(attack_seq)
            mi = check_payload_num(max_seq)

            for i in range(4):
                if i == mi:
                    continue

                # Steps are immutable tuples, so a shallow copy is sufficient.
                test_seq = list(attack_seq)
                test_seq.append((i, (int(blockSize/p_bytes[i]) * scale)))

                proba_a = self._score_candidate(test_seq)

                if proba_a > (get_thd(max_proba, Thd)):
                    max_proba = proba_a
                    max_seq = list(test_seq)
                    print("Max sequence", max_seq, "Max proba", max_proba )

                gc.collect()

            T += 1
            if Thd > 0:
                Thd -= 1
            if max_seq == attack_seq:
                # Stagnation: no candidate was accepted in this round.
                if scale >= max_scale:
                    break
                flag += 1
                scale += flag
                continue

            flag = 0
            scale = 1
            attack_seq = list(max_seq)
            proba = max_proba

        # NOTE(review): the adversarial sample itself is not written anywhere;
        # only the sequence needed to rebuild it is returned.
        label = 1 if proba < 0.5 else 0  # 0 = benign
        attack_num = [0, 0, 0, 0]
        used_bytes = 0
        seq = ""
        for payload, count in attack_seq:
            used_bytes += count * p_bytes[payload]
            attack_num[payload] += count
            seq += str(payload)

        return {'label': label, 'used_bytes': used_bytes, 'attack_num': attack_num, 'iter': T, 'attack_seq': seq}

    def close(self):
        """No-op kept for callers that release the analyzer explicitly."""
        pass

def get_thd(x, Thd):
    """Return ``x`` shifted by a random amount that scales with ``Thd``.

    The shift is one draw from uniform(-0.00002, 0.00005) multiplied by
    ``Thd``, so it vanishes when ``Thd`` is 0.
    """
    jitter = np.random.uniform(low=-0.00002, high=0.00005)   # adjust the values
    return x + jitter * Thd

# avoid overusing one payload
def check_payload_num(seq):
    """Find the payload index that accounts for more than half of all copies.

    Args:
        seq: attack sequence, an iterable of ``(payload_index, count)`` steps.
            Steps whose index is not 0..3 are ignored.

    Returns:
        100 for an empty sequence (a value that matches no payload index, so
        the caller excludes nothing), the dominating index 0..3 when one
        payload holds more than 50% of the injected copies, otherwise None.
        None is also returned when the counts sum to zero, which used to
        raise ZeroDivisionError.
    """
    if not seq:
        return 100

    pn = [0, 0, 0, 0]
    for step in seq:
        for slot in range(4):
            if step[0] == slot:
                pn[slot] += step[1]
                break

    total = sum(pn)
    if total == 0:
        return None

    for slot, count in enumerate(pn):
        if count / total > 0.5:
            return slot
    return None


In [None]:

class DatabaseFactory:
    """Run the attack over a directory of samples and store the outcomes in SQLite.

    Attributes:
        train_db: SQLite file whose ``functions.filename`` column lists files
            to skip.
        db_name: SQLite file receiving the results (table ``functions``).
        root_path: directory scanned recursively for samples.
        scaler_path: joblib file of the feature scaler, passed to ``Attack``.
        model_path: joblib file of the classifier.
        model: the loaded classifier.
    """

    def __init__(self, train_db, db_name, root_path, scaler_path,
                 model_path='./detector/models/RF.joblib'):
        self.train_db = train_db
        self.db_name = db_name
        self.root_path = root_path
        self.scaler_path = scaler_path
        self.model_path = model_path
        self.model = self.load_model()

    @staticmethod
    def _init_worker(sem):
        """Pool initializer: publish the shared DB semaphore as the module global."""
        global pool_sem
        pool_sem = sem

    @staticmethod
    def worker(item):
        """Pool entry point; analyses one item and always returns 0."""
        DatabaseFactory.analyze_file(item)
        return 0

    @staticmethod
    def insert_in_db(db_name, pool_sem, result, filename):
        """Insert one result row into ``functions`` while holding ``pool_sem``.

        The semaphore is released and the connection closed even when the
        insert fails, so one bad row cannot block the other workers.
        """
        path = filename.split(os.sep)
        row = (None,            # id
               path[-1],     # file_name
               result['label'],
               result['avail_bytes'],
               result['used_bytes'],
               result['attack_num'][0],
               result['attack_num'][1],
               result['attack_num'][2],
               result['attack_num'][3],
               result['iter'],
               result['attack_seq'])
        pool_sem.acquire()
        try:
            conn = sqlite3.connect(db_name)
            try:
                cur = conn.cursor()
                cur.execute('''INSERT INTO functions VALUES (?,?,?,?,?,?,?,?,?,?,?)''', row)
                conn.commit()
            finally:
                conn.close()
        finally:
            pool_sem.release()

    @staticmethod
    def analyze_file(item):
        """Attack one sample and record the outcome.

        Args:
            item: ``(filename, results_db_path, model, scaler_path)``.

        Returns:
            0 on success; -1 when the available bytes cannot be determined or
            the attack does not finish within 9000 seconds.
        """
        global pool_sem
        os.setpgrp()

        filename, db, model, scaler_path = item[:4]

        avail_bytes = check_avail_bytes(filename)
        if avail_bytes == -1:
            return -1

        analyzer = Attack(filename, model, scaler_path)
        # The attack runs in a helper thread only so that a timeout can be
        # applied to it.
        pool = ThreadPool(1)
        try:
            res = pool.apply_async(analyzer.attack)
            result = res.get(9000)
        except multiprocessing.TimeoutError:
            # Previously uncaught, which propagated to the parent and aborted
            # the whole batch. The helper thread itself cannot be stopped.
            print("Timed out while attacking:", filename)
            return -1
        finally:
            pool.close()  # non-blocking; lets the helper thread exit when idle

        result['file_size'] = os.path.getsize(filename)  # NOTE(review): not stored by insert_in_db
        result['avail_bytes'] = avail_bytes

        DatabaseFactory.insert_in_db(db, pool_sem, result, filename)

        analyzer.close()
        del result, analyzer, res
        gc.collect()
        return 0

    def load_model(self):
        """Load the classifier from ``self.model_path`` with joblib."""
        # NOTE(review): joblib.load unpickles; only load trusted model files.
        model = load(self.model_path)  # Load the model
        return model

    def create_db(self):
        """Create the results table if it does not exist yet."""
        print('Database creation...')
        conn = sqlite3.connect(self.db_name)
        try:
            conn.execute(''' CREATE TABLE  IF NOT EXISTS functions (id INTEGER PRIMARY KEY, 
                                                                file_name text,
                                                                label numeric,
                                                                avail_bytes numeric,
                                                                used_bytes numeric,
                                                                p_loop numeric,
                                                                p_block numeric,
                                                                p_trans numeric,
                                                                p_arith numeric, 
                                                                T numeric,
                                                                attack_seq text
                                                                );''')
            conn.commit()
        finally:
            conn.close()

    def scan_for_file(self, start):
        """Return the paths of all regular files below ``start``, recursively."""
        file_list = []
        directories = os.listdir(start)
        for item in directories:
            item = os.path.join(start, item)
            if os.path.isdir(item):
                file_list.extend(self.scan_for_file(item + os.sep))
            elif os.path.isfile(item):
                file_list.append(item)
        return file_list

    @staticmethod
    def _fetch_names(db_path, query):
        """Run ``query`` on ``db_path`` and return the first column of every row."""
        conn = sqlite3.connect(db_path)
        try:
            return [row[0] for row in conn.execute(query).fetchall()]
        finally:
            conn.close()

    def remove_override(self, file_list):
        """Drop files whose base name is already in the training or results DB."""
        names = set(self._fetch_names(self.train_db, '''SELECT filename FROM functions'''))
        names.update(self._fetch_names(self.db_name, '''SELECT file_name FROM functions'''))

        return [f for f in file_list if f.split(os.sep)[-1] not in names]

    def build_db(self, processes=4, maxtasksperchild=4):
        """Scan ``root_path`` and attack every file not yet recorded.

        Args:
            processes: number of worker processes.
            maxtasksperchild: tasks a worker handles before it is replaced.
        """
        global pool_sem

        # Serialises the workers' writes to the results database.
        pool_sem = multiprocessing.BoundedSemaphore(value=1)

        self.create_db()
        file_list = self.scan_for_file(self.root_path)
        print('Found ' + str(len(file_list)) + ' during the scan')
        file_list = self.remove_override(file_list)
        print('Find ' + str(len(file_list)) + ' files to analyze')
        random.shuffle(file_list)

        t_args = [(f, self.db_name, self.model, self.scaler_path) for f in file_list]

        # The semaphore is handed to each worker explicitly instead of relying
        # on the workers inheriting this module's globals.
        p = Pool(processes=processes, maxtasksperchild=maxtasksperchild,
                 initializer=DatabaseFactory._init_worker, initargs=(pool_sem,))
        try:
            for _ in tqdm(p.imap_unordered(DatabaseFactory.worker, t_args), total=len(file_list)):
                pass
            p.close()
        except BaseException:
            p.terminate()
            raise
        finally:
            p.join()


In [None]:


if __name__ == '__main__':
    # Samples to attack, and where the outcomes are written.
    file_dir = './mal_samples/'
    db = './results.db'
    # Files already listed in this database are skipped.
    train_db = './train.db'
    # Feature scaler applied before the classifier is queried.
    scaler_path = "./detector/models/scaler.joblib"

    factory = DatabaseFactory(
        train_db=train_db,
        db_name=db,
        root_path=file_dir,
        scaler_path=scaler_path,
    )
    factory.build_db()

