In [1]:
import os,sys, json,re, pickle
import magic, hashlib,  traceback ,ntpath, collections ,lief
from capstone import *
from capstone.x86 import *
import torch.nn as nn
import lief
from elftools.elf.elffile import ELFFile
from transformers import AdamW,AutoTokenizer
from tqdm import tqdm  # for our progress bar
from sklearn.metrics import precision_recall_fscore_support , accuracy_score,f1_score, confusion_matrix,mean_squared_error, mean_absolute_error, r2_score
from numpy import *
from num2words import num2words
import pandas as pd

In [2]:
BIN_FILE_TYPE = 'PE'  # 'PE' or 'ELF'

# Directory of binaries to disassemble.
# NOTE(review): only files ending in ".exe" are selected; ELF binaries usually
# lack that suffix, so adjust the filter when BIN_FILE_TYPE == 'ELF'.
bin_path = '/home/raisul/DATA/temp/x86_pe_msvc_O2_static/'
# How many binaries to process; set to None to process every matching file.
NUM_BIN_FILES = 1
# Sorted so the selected subset is reproducible (os.listdir order is arbitrary).
bin_files = sorted(
    os.path.join(bin_path, f) for f in os.listdir(bin_path) if f.endswith(".exe")
)[:NUM_BIN_FILES]

# Directory of Ghidra ground-truth JSON files, one per binary (read by get_ground_truth_ghidra).
ground_truth_path = '/home/raisul/DATA/temp/ghidra_x86_pe_msvc_O2_debug/'
MODEL_SAVE_PATH = '/home/raisul/probabilistic_disassembly/models/'
EXPERIMENT_NAME = 'align'

# Maximum number of instructions collected by one linear sweep.
MAX_SEQUENCE_LENGTH = 10


In [3]:

def get_ground_truth_ghidra(exe_path, text_section_offset , text_section_len, ground_truth_dir=None):
    """Load Ghidra instruction-start offsets that fall inside the .text section.

    Reads ``<ground_truth_dir>/<basename of exe_path up to its first '.'>.json``.
    Each top-level key of that JSON object is parsed with ``int()``, so keys are
    expected to be decimal strings.

    Args:
        exe_path: path of the analysed binary; only its basename is used.
        text_section_offset: file offset at which .text starts.
        text_section_len: size of .text in bytes.
        ground_truth_dir: directory holding the Ghidra JSON files. Defaults to
            the notebook-level ``ground_truth_path``.

    Returns:
        Sorted list of int offsets ``x`` with
        ``text_section_offset <= x < text_section_offset + text_section_len``.

    Raises:
        OSError: if the JSON file cannot be opened.
        ValueError: if the file is not valid JSON or a key is not an integer string.
    """
    if ground_truth_dir is None:
        ground_truth_dir = ground_truth_path

    # Exclusive end: the byte at offset + len is already outside .text.
    text_section_end = text_section_offset + text_section_len

    # The same naming scheme is used for PE and ELF inputs.
    binary_stem = os.path.basename(exe_path).split('.')[0]
    ghidra_file_path = os.path.join(ground_truth_dir, binary_stem) + '.json'

    with open(ghidra_file_path, "r") as file:
        ghidra_data = json.load(file)

    return sorted(
        offset
        for offset in map(int, ghidra_data)
        if text_section_offset <= offset < text_section_end
    )



def find_data_in_textsection(ground_truth_offsets , text_section_offset , text_section_len, offset_inst_dict):
    """Return .text offsets not covered by any ground-truth instruction.

    For each pair of consecutive ground-truth instruction starts ``(a, b)``, the
    bytes from the end of the instruction at ``a`` up to (excluding) ``b`` are
    reported as non-instruction bytes (data/padding).

    Args:
        ground_truth_offsets: instruction start offsets, expected sorted
            ascending (as returned by get_ground_truth_ghidra).
        text_section_offset: unused; kept for call compatibility.
        text_section_len: unused; kept for call compatibility.
        offset_inst_dict: mapping offset -> decoded instruction (an object with
            ``.size``) or None when the bytes could not be decoded.

    Returns:
        List of int offsets lying in the gaps between consecutive instructions.
    """
    data_offsets = []
    # Pair every instruction with its successor, including the first pair.
    for start, next_start in zip(ground_truth_offsets, ground_truth_offsets[1:]):
        instruction = offset_inst_dict.get(start)
        if instruction is None:
            # Ghidra reports an instruction here but no decoding is available, so
            # its length (and therefore any trailing gap) is unknown; skip it.
            continue
        # Empty range when the next instruction starts at or before this one ends.
        data_offsets.extend(range(start + instruction.size, next_start))
    return data_offsets
    

def linear_sweep(offset_inst , target_offset, max_length=None):
    """Disassemble linearly forward from ``target_offset``.

    Follows fall-through successors (offset + instruction size). It stops after
    ``max_length`` instructions, after an instruction whose mnemonic is ``ret``
    or ``jmp``, or when the next offset is not a key of ``offset_inst``.

    Args:
        offset_inst: mapping offset -> decoded instruction (with ``address``,
            ``size``, ``mnemonic`` and ``op_str``) or None if undecodable.
        target_offset: offset to start from.
        max_length: maximum number of instructions to collect. Defaults to the
            notebook-level ``MAX_SEQUENCE_LENGTH``.

    Returns:
        ``(inst_sequence, address_list)``. ``inst_sequence`` concatenates
        ``"<hex address> <mnemonic> <op_str> ; "`` per instruction, and
        ``address_list`` holds the instruction addresses. Returns ``None`` if an
        undecodable offset is reached, even after some instructions were collected.
    """
    if max_length is None:
        max_length = MAX_SEQUENCE_LENGTH

    inst_sequence = ''
    address_list = []
    current_offset = target_offset

    for _ in range(max_length):
        if current_offset not in offset_inst:
            # Left the decoded region (e.g. past the end of .text); nothing more to follow.
            break

        current_instruction = offset_inst[current_offset]
        if current_instruction is None:
            return None

        current_offset += current_instruction.size
        inst_sequence += (hex(current_instruction.address) + " " + current_instruction.mnemonic
                          + ' ' + current_instruction.op_str + ' ; ')
        address_list.append(current_instruction.address)

        # Stop at instructions that end fall-through control flow.
        if current_instruction.mnemonic in ("ret", "jmp"):
            break

    return inst_sequence, address_list
    

In [17]:




SEQUENCES = []
LABELS     = []

for bin_file_path in bin_files:

    # NOTE(review): the input directory names an "x86" build, but capstone is put
    # in 64-bit mode. If these are 32-bit binaries, CS_MODE_32 is required (the
    # `[rax]` operands in the output below are suspicious). Confirm the bitness.
    md = Cs(CS_ARCH_X86, CS_MODE_64)
    md.detail = True
    offset_inst = {}

    # --- Locate .text and load the Ghidra ground truth; skip this binary on any failure.
    with open(bin_file_path, 'rb') as f:
        try:
            if BIN_FILE_TYPE == "ELF":
                elffile = ELFFile(f)
                text_section = elffile.get_section_by_name('.text')
                textSection = text_section.data()
                text_section_offset = text_section['sh_offset']

            elif BIN_FILE_TYPE == "PE":
                pe_file = lief.parse(bin_file_path)
                text_section = pe_file.get_section(".text")
                text_section_offset = text_section.pointerto_raw_data
                textSection = bytes(text_section.content)

            else:
                # Otherwise textSection would be undefined, or stale from a previous binary.
                raise ValueError(f"unsupported BIN_FILE_TYPE: {BIN_FILE_TYPE!r}")

            ground_truth_offsets = get_ground_truth_ghidra(bin_file_path, text_section_offset , len(textSection))

        except Exception as e:
            print("An error occurred:", e ,bin_file_path)
            continue

    # --- Decode one candidate instruction at every byte offset of .text.
    # 15 bytes is the maximum length of an x86 instruction.
    inst_sizes = {}
    for byte_index in range(len(textSection)):
        file_offset = text_section_offset + byte_index
        try:
            instruction = next(md.disasm(textSection[byte_index: byte_index+15], file_offset), None)
            offset_inst[file_offset] = instruction          # None when undecodable
            inst_sizes[file_offset] = instruction.size if instruction else None
        except Exception as e:
            # print_exc() writes the traceback itself and returns None;
            # wrapping it in print() only emitted a stray "None".
            traceback.print_exc()
            print(e)

    offset_inst_dict = collections.OrderedDict(sorted(offset_inst.items()))

    DATA_OFFSETS = find_data_in_textsection(ground_truth_offsets , text_section_offset , len(textSection) , offset_inst)

    # Exclusive end of .text (use text_section_offset + 200 to inspect only the first bytes).
    code_boundary = text_section_offset+len(textSection)

    # --- From every decodable start offset, follow up to `short_range` fall-through
    # successors and count how many such chains reach each byte as an instruction start.
    boundary_occur_count = {i: 0 for i in range(text_section_offset, code_boundary)}

    short_range=30

    for byte_offset in range(text_section_offset, code_boundary):
        current_instruction = offset_inst_dict.get(byte_offset)
        if current_instruction is None:   # undecodable or missing start offset
            continue
        instruction_end_byte = byte_offset + current_instruction.size
        if instruction_end_byte>=code_boundary:
            # No successor inside .text for THIS start; later offsets may still
            # have one, so move on instead of abandoning the whole scan.
            continue
        for _ in range(short_range):
            boundary_occur_count[instruction_end_byte] += 1
            current_instruction = offset_inst_dict.get(instruction_end_byte)
            if current_instruction is None:
                break
            instruction_end_byte += current_instruction.size
            if instruction_end_byte>=code_boundary:
                break

    # --- Dump one line per .text byte: '#' marks a Ghidra instruction start, followed
    # by the offset, its boundary count, the decoded instruction and its length.
    ground_truth_set = set(ground_truth_offsets)   # O(1) membership instead of a list scan per byte
    for offset in boundary_occur_count:
        label = '#' if offset in ground_truth_set else ' '
        instruction = offset_inst_dict.get(offset)
        if instruction:
            inst_str = "{:<25}".format(instruction.mnemonic + ' ' + instruction.op_str)
        else:
            inst_str = None
        print(label, f"{offset:<6}", ' : ', f"{boundary_occur_count[offset]:<6} : ", inst_str,
              '         len :', instruction.size if instruction else '')

    print('x x '*100)

  1024    :  0      :  int3                               len : 1
  1025    :  1      :  int3                               len : 1
  1026    :  2      :  int3                               len : 1
  1027    :  3      :  int3                               len : 1
  1028    :  4      :  int3                               len : 1
# 1029    :  5      :  jmp0xf4c                           len : 5
  1030    :  0      :  oreax, dword ptr [rax]             len : 3
  1031    :  0      :  oreax, dword ptr [rax]             len : 2
  1032    :  0      :  addbyte ptr [rax], al              len : 2
  1033    :  2      :  addcl, ch                          len : 2
# 1034    :  7      :  jmp0x1dc8                          len : 5
  1035    :  3      :  movecx, 0xe9000019                 len : 5
  1036    :  0      :  sbbdword ptr [rax], eax            len : 2
  1037    :  0      :  addbyte ptr [rax], al              len : 2
  1038    :  1      :  addcl, ch                          len : 2
# 1039    

In [5]:
# To run this pipeline as a script, export the notebook and launch it with Accelerate:
#   jupyter nbconvert --to script data_pipe.ipynb
#   accelerate launch data_pipe.py > log.txt