
How to generate our own pretrain dataset? #18

Open
RobinHan24 opened this issue Mar 11, 2022 · 4 comments

Comments

@RobinHan24

As mentioned in the README, I ran the script preprocess_pretrain_10k.py to generate data in data-bin/pretrain_10k, but how can I generate my own data in data-src/pretrain_10k? Thanks a lot.


RobinHan24 commented Mar 14, 2022

I have finished a script that can generate the pretrain dataset. It runs, but I am not sure it is exactly right. Could you please help me check it? Thanks a lot.

1. python command/pretrain/prepare_json.py (in: data-raw/bin, out: data-raw/funcbytes)
2. python command/finetune/prepare_finetune_trace.py (in: data-raw/funcbytes, out: data-raw/functraces)
3. python command/finetune/prepare_finetune_single.py (in: data-raw/functraces, out: data-src/pretrain)
4. python command/pretrain/preprocess_pretrain.py (in: data-src/pretrain, out: data-bin/pretrain/)
5. ./command/pretrain/pretrain.sh
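The five steps above can be sketched as a small driver (hypothetical; the script paths are copied verbatim from the list and assumed to exist in the repo checkout):

```python
# Hypothetical driver for the five steps above; the script paths are
# taken from the list and are assumptions about the repo layout.
import subprocess

PIPELINE = [
    ["python", "command/pretrain/prepare_json.py"],             # data-raw/bin        -> data-raw/funcbytes
    ["python", "command/finetune/prepare_finetune_trace.py"],   # data-raw/funcbytes  -> data-raw/functraces
    ["python", "command/finetune/prepare_finetune_single.py"],  # data-raw/functraces -> data-src/pretrain
    ["python", "command/pretrain/preprocess_pretrain.py"],      # data-src/pretrain   -> data-bin/pretrain
    ["bash", "command/pretrain/pretrain.sh"],
]

def run_pipeline(dry_run=True):
    """Run each stage in order; with dry_run=True only print the commands."""
    for cmd in PIPELINE:
        if dry_run:
            print("would run:", " ".join(cmd))
        else:
            # check=True aborts the pipeline as soon as one stage fails
            subprocess.run(cmd, check=True)
```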

The script command/finetune/prepare_finetune_single.py is shown below.


import os
from random import choice
import random
import argparse
import json


class Options():

    def __init__(self):
        self.initialized = False
        self.parser = None
        self.args = None
        self.epoch_needed_list = [-1]
        # the trace file contains 4 epochs of traces; only the one
        # filled with '########' is needed here (a single epoch)
        self.output_filename_prefix_list = ['train', 'valid']
        self.output_filename_prefix_prob_list = None
        self.output_filename_inter_list = ['static', 'inst_pos_emb', 'op_pos_emb', 'byte1', 'byte2', 'byte3', 'byte4',
                                           'arch_emb']
        self.positive_flag_list = [True, False]
        self.positive_flag_prob_list = [0.5, 0.5]
        self.archs = ['arm', 'mips', 'x86', 'x86_64']
        self.opts = ['O0', 'O1', 'O2', 'O3', 'orig',
                     'bcfobf', 'cffobf', 'splitobf', 'subobf', 'acdobf', 'indibran', 'strcry', 'funcwra']
        self.optimizations = ['O0', 'O1', 'O2', 'O3']
        self.obfs = ['bcfobf', 'cffobf', 'splitobf', 'subobf', 'acdobf', 'indibran', 'strcry', 'funcwra']
        self.arch_to_optListdict = {}
        self.opt_to_archListdict = {}

    def initialize(self, parser):

        '''
        parser.add_argument( '-archs','--archs_wanted_list', action="extend", nargs="*", type=str, required=False,
                             help="archs we want", default=archs)
        '''
        parser.add_argument('-n', '--sample_num', type=int, required=False, help="the number of samples", default=200)
        parser.add_argument('-obf', '--only_obf', action='store_true', required=False, default=False)
        parser.add_argument('-optimization', '--only_optimization', action='store_true', required=False, default=False)
        parser.add_argument('-can_inter', '--valid_train_func_can_intersection_flag', action='store_true',
                            required=False,
                            help="the valid dataset and the training dataset may intersect", default=False)
        parser.add_argument('-c', '--train_test_ratio', type=float, required=False,
                            help="the ratio of training samples",
                            default=0.1)
        parser.add_argument('-newline', '--tokens_newline_number', type=int, required=False,
                            help="number of token new lines", default=512)
        parser.add_argument('-archs', '--archs_wanted_list', type=str, nargs='*', required=False,
                            help="archs wanted", default=self.archs)
        parser.add_argument('-opts', '--opts_wanted_list', type=str, nargs='*', required=False, default=self.opts)
        parser.add_argument('-arch_same', '--arch_must_same_flag', action='store_true', required=False, default=False)
        parser.add_argument('-opt_differ', '--opt_must_differ_flag', action='store_true', required=False, default=False)
        parser.add_argument('-opt_same', '--opt_must_same_flag', action='store_true', required=False, default=False)
        parser.add_argument('-i', '--functraces_folder_path', type=str, required=False, default='data-raw/functraces')
        parser.add_argument('-o', '--output_folder_path', type=str, required=False, default='data-src/pretrain')
        parser.add_argument('-trunc', '--tokens_truncate_flag', action='store_true', required=False, default=False)
        parser.add_argument('-minlen', '--trace_min_len', type=int, required=False, default=50)
        # if the value is 10, the ratio of positive to negative training pairs is 1:9,
        # and the 9 negative pairs share the same left function as the positive pair
        parser.add_argument('-training_cycle', '--training_cycle', type=int, required=False, default=10)
        parser.add_argument('-valid_cycle', '--valid_cycle', type=int, required=False, default=10)
        parser.add_argument('-bins', '--binfile_keyword_list_needed', type=str, nargs='*',
                            required=False, default=[])

        self.initialized = True
        self.parser = parser
        self.banned_function_list = ['skip_white', 'free_dir', 'parse_name', 'blake2b_increment_counter', 'main',
                                     'blake2b_compress', 'register_tm_clones', 'write_pending', 'millerrabin',
                                     'blake2b_set_lastblock', 'print_name', 'read_string', 'blake2b_init_param',
                                     'check', 'base64_decode', 'print_stuff', 'blake2b_init0', 'usage',
                                     'cleanup', 'frame_dummy', 'print_entry', 'print_user', 'print_stats',
                                     'base64_encode', 'base64url_encode', 'blake2b_init', 'deregister_tm_clones']

    def parse(self):

        if not self.initialized:  # check if it has been initialized
            parser = argparse.ArgumentParser()
            self.initialize(parser)

        self.args = self.parser.parse_args()
        self.output_filename_prefix_prob_list = [self.args.train_test_ratio, 1 - self.args.train_test_ratio]

        for arch in self.archs:
            self.arch_to_optListdict[arch] = []

        for opt in self.opts:
            self.opt_to_archListdict[opt] = []

        for arch_opt_name in os.listdir(self.args.functraces_folder_path):
            tmp_list = arch_opt_name.split('-')
            tmp_arch = arch_str_to_arch_dict[f'{tmp_list[0]}-{tmp_list[1]}']
            tmp_opt = f'{tmp_list[2]}'
            self.arch_to_optListdict[tmp_arch].append(tmp_opt)
            self.opt_to_archListdict[tmp_opt].append(tmp_arch)

        return self.args


class BinaryFileInfo(object):

    def __init__(self, arch=None, opt=None, func_name=None, proj_name=None, trace_path=None):
        """
        Holds the metadata of one trace file.

        :param arch: string, like 'x86' or 'mips'
        :param opt: string, like 'O3' or 'bcfobf'
        :param func_name: string, name of the traced function
        :param proj_name: string, name of the project the function comes from
        :param trace_path: string, path to the trace file
        """
        self.arch = arch
        self.opt = opt
        self.func_name = func_name
        self.proj_name = proj_name
        self.trace_path = trace_path

    def __str__(self):
        return f'BinaryFileInfo: trace_path {self.trace_path}, arch {self.arch}, ' \
               f'opt {self.opt}, func_name {self.func_name}, proj_name {self.proj_name}'


def two_lists_intersection(list1, list2):
    """Return the intersection of two lists (as a list)."""
    return list(set(list1).intersection(set(list2)))


def list1_minus_list2(list1, list2):
    """Return the elements of list1 that are not in list2."""
    return list(set(list1).difference(set(list2)))


def getAllFuncPath(options):
    """Collect 'folder$$arch$$opt' path strings for every wanted arch/opt pair."""
    AllFuncNameList = []
    for arch in options.args.archs_wanted_list:
        for opt in options.args.opts_wanted_list:
            if opt in options.arch_to_optListdict[arch]:
                arch_opt_name = f'{arch_to_arch_str_dict[arch]}-{opt}'
                # e.g. 'functraces/arm-32-O3' or 'functraces/mips-32-O1'
                arch_opt_folder_path = f'{options.args.functraces_folder_path}/{arch_opt_name}'
                for func_name in os.listdir(arch_opt_folder_path):
                    AllFuncNameList.append(f'{arch_opt_folder_path}/{func_name}$${arch}$${opt}')
    return AllFuncNameList


def get_trace_file_info(func_folder_path_arch_opt, options=None):
    """Pick four random trace files from one function folder."""
    func_folder_path, arch, opt = func_folder_path_arch_opt.split('$$')
    random_func_name = func_folder_path.split('/')[-1]

    trace_name_list = os.listdir(func_folder_path)
    if len(trace_name_list) == 0:
        return None, None, None, None

    trace_file_infos = []
    for _ in range(4):
        trace_name = choice(trace_name_list)
        trace_path = f'{func_folder_path}/{trace_name}'
        trace_file_infos.append(
            BinaryFileInfo(arch, opt, random_func_name, trace_name, trace_path))

    return tuple(trace_file_infos)

    

def value_list2seq(value_list):
    """
    :param value_list: like ['########', '00510000', ...]
    :return: like [['##', '##', '##', '##'], ['00', '51', '00', '00'], ...]
    """
    value_list_transpose = []
    for value in value_list:
        value_list_transpose.append([value[i:i + 2] for i in range(0, len(value), 2)])
    return value_list_transpose


def value_to_four_byte_list(value_list):
    """Split each 4-byte value into per-byte columns byte1..byte4."""
    byte_sequence_list = value_list2seq(value_list)
    return ([i[0] for i in byte_sequence_list], [i[1] for i in byte_sequence_list],
            [i[2] for i in byte_sequence_list], [i[3] for i in byte_sequence_list])


def trace_file_info_to_data_list(trace_file_info, options=None):
    """
    :param trace_file_info: the trace file info object to load
    :return: 8 lists matching options.output_filename_inter_list:
             [static], [inst_pos_emb], [op_pos_emb], [byte1], [byte2], [byte3], [byte4],
             [arch, opt/obfuscation, proj_name, func_name]
    """

    assert len(options.epoch_needed_list) == 1
    assert os.path.isfile(trace_file_info.trace_path)
    data_list = []
    # an empty trace file would make json.load fail, so skip it
    size = os.path.getsize(trace_file_info.trace_path)
    if size == 0:
        return None
    with open(trace_file_info.trace_path, 'r') as trace:
        trace_epoch_index = options.epoch_needed_list[0]
        trace_list = json.load(trace)
        # a file with 4 epochs has valid indices 0..3
        if len(trace_list) <= trace_epoch_index:
            return None
        trace_epoch_list = trace_list[trace_epoch_index]
        data_list.append(trace_epoch_list[0])
        data_list.append(trace_epoch_list[2])
        data_list.append(trace_epoch_list[3])
        bytes_list = list(value_to_four_byte_list(trace_epoch_list[1]))
        data_list.append(bytes_list[0])
        data_list.append(bytes_list[1])
        data_list.append(bytes_list[2])
        data_list.append(bytes_list[3])
        data_list.append([arch_to_arch_data_dict[trace_file_info.arch], trace_file_info.opt,
                          trace_file_info.proj_name, trace_file_info.func_name])

    return data_list


def write_str_to_file(output_filename_prefix, output_filename_inter, 
                      str_to_write, write_type, options):
    final_data_folder = options.args.output_folder_path

    filename = f'{output_filename_prefix}.{output_filename_inter}'
    file_path = os.path.join(final_data_folder, filename)
    with open(file_path, write_type) as wf:
        wf.write(str_to_write)


def append_dict_to_outputfile_as_str(pair_dict, output_filename_prefix, options=None):

    f1_list = pair_dict['f1']  # input0

    output_filename_inter_list = options.output_filename_inter_list
    assert len(f1_list) == len(output_filename_inter_list)

    tokens_newline_number = options.args.tokens_newline_number
    tokens_truncate_flag = options.args.tokens_truncate_flag
    if not tokens_truncate_flag and len(f1_list[0]) > tokens_newline_number:
        print("f1 too long, discard")
        return None

    line_token_num1 = min(len(f1_list[0]), tokens_newline_number)

    if tokens_newline_number < len(f1_list[0]):
        print("truncate")

    if line_token_num1 < options.args.trace_min_len:
        print("Trace too short!! The length of the trace is", str(line_token_num1))
        return None

    assert 0 < line_token_num1 <= tokens_newline_number

    if tokens_truncate_flag:
        assert line_token_num1 is not None
        f1_list = [arr[:line_token_num1] for arr in f1_list]

    for output_filename_inter_index in range(len(output_filename_inter_list) - 1):
        # write every field except 'arch_emb':
        # ['static', 'inst_pos_emb', 'op_pos_emb', 'byte1', 'byte2', 'byte3', 'byte4']
        output_filename_inter = output_filename_inter_list[output_filename_inter_index]
        assert line_token_num1 is not None
        assert len(f1_list[0]) == line_token_num1
        str_to_write1 = " ".join([str(x) for x in f1_list[output_filename_inter_index]])
        write_str_to_file(output_filename_prefix, output_filename_inter, str_to_write1 + "\n", 'a', options)

    output_filename_inter = output_filename_inter_list[-1]  # 'arch_emb'
    arch1 = f1_list[-1][0]  # 1st element of the metadata list, like ["x64", "O2", "", ""]

    # repeat the arch id once per token so the line lengths match
    append_arch_list1 = [arch1 for i in range(line_token_num1)]
    str_to_write1 = " ".join(append_arch_list1)

    write_str_to_file(output_filename_prefix, output_filename_inter, str_to_write1 + "\n", 'a', options)

    return True


def generate5type(trace_file_info1, output_filename_prefix=None, options=None):
    """Write one trace to the output files; return True on success, None otherwise."""
    result_dict = {}
    data_list1 = trace_file_info_to_data_list(trace_file_info1, options)
    if data_list1 is None:
        return None

    result_dict['f1'] = data_list1

    if append_dict_to_outputfile_as_str(result_dict, output_filename_prefix, options) is None:
        return None
    return True

def init_output_files(options):

    for output_filename_prefix in options.output_filename_prefix_list:
        for output_filename_inter in options.output_filename_inter_list:
            write_str_to_file(output_filename_prefix, output_filename_inter,  "", 'w', options)


arch_str_to_arch_dict = {'x86-32': 'x86',
                         'x86-64': 'x86_64',
                         'arm-32': 'arm',
                         'mips-32': 'mips'}

arch_to_arch_str_dict = {'x86': 'x86-32',
                         'x86_64': 'x86-64',
                         'arm': 'arm-32',
                         'mips': 'mips-32'}

arch_to_arch_data_dict = {'x86': 'x86',
                          'x86_64': 'x64',
                          'arm': 'arm',
                          'mips': 'mips'}


def data_split(full_list, ratio, shuffle=True):
    """
    Split full_list randomly according to ratio.
    :param full_list: list of samples to split
    :param ratio: fraction of samples that go to the first sublist
    :param shuffle: whether to shuffle before splitting
    :return: (sublist_1, sublist_2)
    """
    n_total = len(full_list)
    offset = int(n_total * ratio)
    if n_total == 0 or offset < 1:
        return [], full_list
    if shuffle:
        random.shuffle(full_list)
    sublist_1 = full_list[:offset]
    sublist_2 = full_list[offset:]
    return sublist_1, sublist_2


if __name__ == '__main__':

    # parse the command-line options
    options = Options()
    args = options.parse()
    if args.only_optimization:
        args.opts_wanted_list = options.optimizations
    if args.only_obf:
        args.opts_wanted_list = options.obfs

    assert set(options.args.archs_wanted_list) <= set(options.archs)
    assert set(options.args.opts_wanted_list) <= set(options.opts)

    init_output_files(options)
    AllPathList = getAllFuncPath(options)
    trainList, validList = data_split(AllPathList, 0.8)
    print(len(trainList))

    for output_filename_prefix, path_list in (('train', trainList), ('valid', validList)):
        for path in path_list:
            trace_file_info1, trace_file_info2, trace_file_info3, trace_file_info4 = \
                get_trace_file_info(path, options)
            if trace_file_info1 is not None:
                # write the four randomly chosen traces of this function to file
                tmp_flag = generate5type(trace_file_info1, output_filename_prefix, options)
                if tmp_flag:
                    tmp_flag2 = generate5type(trace_file_info2, output_filename_prefix, options)
                    tmp_flag3 = generate5type(trace_file_info3, output_filename_prefix, options)
                    tmp_flag4 = generate5type(trace_file_info4, output_filename_prefix, options)
                    print(tmp_flag, tmp_flag2, tmp_flag3, tmp_flag4)
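For reference, the byte-column transposition the script performs (value_list2seq followed by value_to_four_byte_list) can be checked in isolation with the example values from its docstring:

```python
# Standalone check of the byte-column transposition used in the script:
# each 8-hex-char value is cut into four 2-char bytes, and the bytes are
# then gathered column-wise into the byte1..byte4 streams.

def split_bytes(value):
    # '00510000' -> ['00', '51', '00', '00']
    return [value[i:i + 2] for i in range(0, len(value), 2)]

values = ["########", "00510000"]  # '########' is the dummy placeholder
byte1, byte2, byte3, byte4 = zip(*[split_bytes(v) for v in values])
# byte1 == ('##', '00'), byte2 == ('##', '51')
```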

                



@peikexin9
Member

Hi @RobinHan24, thanks for posting your scripts. Since you are pretraining, byte1-4 need to include real execution traces, not the dummy traces used only in finetuning. While your generated data format might look correct, it might not include the actual traces. This comes from your 2nd step, "python command/finetune/prepare_finetune_trace.py", which only prepares the dataset with dummy traces (where byte1-4 are mostly dummy values). If you want to generate actual traces, you need an emulator and have to really execute the code you collected in funcbytes (you may want to look at micro_trace/prepare_code_trace.py). This step is not so straightforward to get right, which is why I provided the pretrained model. But if you really want to generate your own dataset for pretraining, we can talk and I can walk you through the steps.
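A quick way to spot the problem described above is to measure how much of the generated byte1-4 data is the dummy placeholder (a hypothetical helper; '##' is assumed to be the placeholder token, as in the script's docstrings):

```python
# Hypothetical sanity check: if the byte1-4 files produced by the
# finetune scripts consist mostly of the '##' placeholder, the traces
# were never really executed and are not suitable for pretraining.
def dummy_ratio(byte_lines):
    """Fraction of '##' placeholder tokens across the given lines."""
    tokens = [tok for line in byte_lines for tok in line.split()]
    return sum(tok == "##" for tok in tokens) / max(len(tokens), 1)

# a ratio close to 1.0 means the data only carries dummy traces
```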

@RobinHan24
Author

@peikexin9 Thanks again. I'm curious how you collected so much vulnerability data in order to uncover vulnerabilities that have not yet been discovered in firmware images. Could you please share your experience or methods? Thank you very much.
