In [1]:
import os,sys, json,re, pickle
import magic, hashlib,  traceback ,ntpath, collections ,lief
from capstone import *
from capstone.x86 import *
import torch.nn as nn
import lief
from elftools.elf.elffile import ELFFile
from transformers import AdamW,AutoTokenizer
from tqdm import tqdm  # for our progress bar
from sklearn.metrics import precision_recall_fscore_support , accuracy_score,f1_score, confusion_matrix,mean_squared_error, mean_absolute_error, r2_score
from numpy import *
from num2words import num2words
import pandas as pd
from collections import defaultdict

In [2]:
# Container format of the binaries under analysis: 'PE' or 'ELF'.
BIN_FILE_TYPE = 'PE'

# Directory holding the executables to disassemble.
bin_path = '/home/raisul/DATA/temp/x86_pe_msvc_O2_static/'

# os.listdir() returns entries in arbitrary order, so sort before slicing;
# otherwise "the first executable" differs between machines and runs.
# The [0:1] slice restricts the run to a single binary.
bin_files = sorted(
    os.path.join(bin_path, f) for f in os.listdir(bin_path) if f.endswith(".exe")
)[0:1]

# Directory with one <binary-name>.json ground-truth file per executable.
ground_truth_path = '/home/raisul/DATA/temp/ghidra_x86_pe_msvc_O2_debug/'

MODEL_SAVE_PATH = '/home/raisul/probabilistic_disassembly/models/'
EXPERIMENT_NAME = 'align'



In [3]:

def get_ground_truth_ghidra(exe_path, text_section_offset , text_section_len):
    """Return the sorted ground-truth offsets that lie inside the text section.

    The ground truth is read from ``<ground_truth_path>/<name>.json`` where
    ``<name>`` is the part of the executable's base name before its first
    dot. Only the JSON object's keys are used; each key must parse with
    ``int()``.

    Args:
        exe_path: Path of the executable; only its base name is used.
        text_section_offset: First offset belonging to the text section.
        text_section_len: Length of the text section in bytes.

    Returns:
        Ascending list of integer offsets ``x`` with
        ``text_section_offset <= x < text_section_offset + text_section_len``.

    Raises:
        OSError: If the ground-truth file cannot be opened.
        ValueError: If the file is not valid JSON or a key is not an integer.
    """
    # Half-open range: the byte at offset + len is the first byte *after*
    # the section, so it must not be accepted as a text-section offset.
    text_section_end = text_section_offset + text_section_len

    exe_name = os.path.basename(exe_path)
    ghidra_file_path = os.path.join(ground_truth_path, exe_name.split('.')[0]) + '.json'

    with open(ghidra_file_path, "r") as file:
        ghidra_data = json.load(file)

    return sorted(
        offset
        for offset in map(int, ghidra_data)
        if text_section_offset <= offset < text_section_end
    )



def find_data_in_textsection(ground_truth_offsets , text_section_offset , text_section_len, offset_inst_dict):
    """Collect the byte offsets that sit between consecutive ground-truth instructions.

    For every pair of neighbouring ground-truth offsets, the decoded length
    of the first instruction is compared with the distance to the next one.
    When they differ, every offset from the end of the first instruction up
    to (not including) the next ground-truth offset is reported.

    Args:
        ground_truth_offsets: Ascending list of instruction start offsets.
        text_section_offset: Unused; kept for call compatibility.
        text_section_len: Unused; kept for call compatibility.
        offset_inst_dict: Mapping from offset to a decoded instruction
            exposing ``.size``.

    Returns:
        List of offsets not covered by any ground-truth instruction, in
        ascending order. Bytes after the last ground-truth offset are not
        examined.

    Raises:
        KeyError: If a ground-truth offset is missing from ``offset_inst_dict``.
        AttributeError: If a ground-truth offset maps to ``None``.
    """
    data_offsets = []
    # Pair every offset with its successor, starting at the very first entry.
    # Starting at index 1 would miss any gap after the first instruction.
    for current, following in zip(ground_truth_offsets, ground_truth_offsets[1:]):
        inst_len = offset_inst_dict[current].size
        if following - current != inst_len:
            # When the instruction overlaps its successor the range is empty.
            data_offsets.extend(range(current + inst_len, following))
    return data_offsets
    

def linear_sweep(offset_inst , target_offset):
    """Follow instructions sequentially from ``target_offset`` and render them as text.

    At most ``MAX_SEQUENCE_LENGTH`` steps are taken. The walk stops early
    after a ``ret`` or ``jmp`` mnemonic, or once the current offset is no
    longer a key of ``offset_inst``.

    Args:
        offset_inst: Mapping from offset to a decoded instruction (or ``None``
            where decoding failed).
        target_offset: Offset at which the walk starts.

    Returns:
        ``None`` if the walk reaches an offset whose entry is ``None``;
        otherwise a tuple ``(text, addresses)`` where ``text`` is the
        concatenation of ``"<hex address> <mnemonic> <operands> ; "`` for each
        visited instruction and ``addresses`` lists their addresses.
    """
    rendered = []
    address_list = []
    cursor = target_offset

    for _ in range(MAX_SEQUENCE_LENGTH):
        # Once the cursor leaves the decoded range it can never re-enter it.
        if cursor not in offset_inst:
            break

        insn = offset_inst[cursor]
        if insn is None:
            return None

        cursor += insn.size
        rendered.append(str(hex(insn.address)) + " " + insn.mnemonic + ' ' + insn.op_str + ' ; ')
        address_list.append(insn.address)

        # Control leaves the straight-line sequence here.
        if insn.mnemonic in ("ret", "jmp"):
            break

    return ''.join(rendered), address_list
    

In [4]:





for bin_file_path in bin_files:

    # 64-bit x86 decoder. Detail mode is switched on because the later cells
    # read each instruction's .groups and .operands.
    md = Cs(CS_ARCH_X86, CS_MODE_64)
    md.detail = True
    offset_inst = {}

    with open(bin_file_path, 'rb') as f:

        try:
            if BIN_FILE_TYPE == "ELF":
                elffile = ELFFile(f)
                textSection = elffile.get_section_by_name('.text').data()
                text_section_offset = elffile.get_section_by_name('.text')['sh_offset']

            elif BIN_FILE_TYPE == "PE":
                # lief parses from the path; the open handle is only used by the ELF branch.
                pe_file = lief.parse(bin_file_path)
                text_section = pe_file.get_section(".text")
                text_section_offset = text_section.pointerto_raw_data
                textSection = bytes(text_section.content)

            ground_truth_offsets = get_ground_truth_ghidra(bin_file_path, text_section_offset , len(textSection))

        except Exception as e:
            # Best effort: report the binary that could not be loaded and move on.
            print("An error occurred:", e ,bin_file_path)
            continue

    # Attempt to decode one instruction at every byte offset of the text
    # section. Offsets that do not decode are stored as None.
    inst_sizes = {}
    for byte_index in range(len(textSection)):
        try:
            # 15 bytes is the maximum length of an x86 instruction, so this
            # window always holds a complete encoding when one exists.
            instruction = next(md.disasm(textSection[byte_index: byte_index+15 ], text_section_offset + byte_index ), None)
            offset_inst[text_section_offset+byte_index] = instruction
            inst_sizes [text_section_offset+byte_index] = instruction.size if instruction else None

        except Exception:
            # print_exc() writes the traceback (message included) itself and
            # returns None, so it must not be wrapped in print().
            traceback.print_exc()

    offset_inst_dict = collections.OrderedDict(sorted(offset_inst.items()))

    DATA_OFFSETS = find_data_in_textsection(ground_truth_offsets , text_section_offset , len(textSection) , offset_inst)

    # First offset past the end of the text section.
    code_boundary = text_section_offset+len(textSection)

# The remaining cells work on the decode map of the last processed binary.
disasm  = offset_inst_dict

In [5]:
# First and last keys of the decode map, in its iteration order.
min_offset, max_offset = (list(disasm)[position] for position in (0, -1))
print(min_offset, max_offset)

1024 22736


In [6]:


def _compute_occlusion(disasm):
    """ Identify overlapping instructions and remove """
    occlusion = defaultdict(list)
    valid_instructions = set()

    for offset, details in disasm.items():
        if details!= None:
            for i in range(offset + 1, offset + details.size):
                occlusion[i].append(offset)

    # fix nahid
    covered = set()
    for offset in sorted(disasm.keys()):
        if offset in covered:
            # print(f"Skipping {offset} due to occlusion")
            continue  # Skip if another instruction already claimed this byte



        valid_instructions.add(offset)
        for i in range(offset, offset + disasm[offset].size):
            covered.add(i)  # Mark all bytes of this instruction as covered

    # print(f"Final valid instructions after occlusion: {sorted(valid_instructions)}")
    return occlusion, valid_instructions
# Overlap map and greedy set of accepted instruction starts for the decoded binary.
occlusion_space, valid_instructions =_compute_occlusion(disasm)


In [7]:
# CONTROL_GROUPS = {
#     CS_GRP_JUMP,
#     CS_GRP_CALL,
#     CS_GRP_RET,
#     CS_GRP_IRET,
# }

# for key,val in disasm.items():
#     # print(val.groups)
#     if val == None:
#         continue
#     for group in val.groups:
#         if group in CONTROL_GROUPS:
#             print(f"0x{val.address:x}: {val.mnemonic} {val.op_str}")

In [8]:
# CONTROL_GROUPS ={"CALL", "COND_BR", "UNCOND_BR", "RET"}
# Capstone instruction groups treated as control-flow instructions when the
# successor map is built.
CONTROL_GROUPS = {CS_GRP_JUMP, CS_GRP_CALL, CS_GRP_RET, CS_GRP_IRET}


def _compute_destinations(disasm):
    """ Compute successor addresses (CFG) and ensure function epilogues are correctly identified. """
    dests, preds = {}, defaultdict(list)
    last_offset = list(disasm.keys())[-1]
    first_offset = list(disasm.keys())[0]

    for offset, details in disasm.items():
        if details==None:
            continue
        inst_str = details.mnemonic +' ' + details.op_str
        next_offset = offset + details.size



        if not set(details.groups) & CONTROL_GROUPS:
            # Default fallthrough for non-control flow instructions
            if next_offset <= last_offset:
                dests[offset] = [next_offset]
                # preds[next_offset].append(offset)
            else:
                dests[offset] = []
        else: #control instruction
            #unconditional jump
            if details.id == X86_INS_JMP and details.operands and details.operands[0].type == CS_OP_IMM:
                 # Unconditional jump
                op_value = details.operands[0].imm
                if op_value>=first_offset and op_value<=last_offset:
                    dests[offset] = [op_value]
                    # preds[op_value].append(offset)
            
            # elif "COND_BR" in details.groups or "CALL" in details.groups:
            elif (CS_GRP_JUMP in details.groups or CS_GRP_CALL in details.groups) :
                if details.operands and details.operands[0].type == CS_OP_IMM:
                    jump_target = details.operands[0].imm
    
                    if next_offset<=last_offset:
                        dests[offset] = [next_offset]
                        # preds[next_offset].append(offset)
                                     
                    if jump_target>=first_offset and jump_target<=last_offset and jump_target!=next_offset:
                        if offset in dests:
                            dests[offset].append(jump_target)
                        else:
                            dests[offset] = [jump_target]
        
            else:
                # print('>>>>>  ',offset, ' : ' ,inst_str)
                dests[offset] = None

        if offset in dests:
            if dests[offset] is not None:
                for target in dests[offset]:
                    preds[target].append(offset)

    return dests, preds

# Build both maps, then re-key each as a plain dict ordered by offset.
cfg, preds = (
    dict(sorted(mapping.items())) for mapping in _compute_destinations(disasm)
)

def has_duplicates(lst):
    """Return True when at least one element of ``lst`` occurs more than once."""
    distinct = set(lst)
    return len(distinct) < len(lst)

# Print every decoded instruction followed by its successors and predecessors.
for offset, inst in disasm.items():
    if not inst:
        continue
    print(f"{offset}  :  {hex(offset)}   {inst.mnemonic} {inst.op_str}     {inst.size}")
    if offset in cfg:
        print('cfg ', cfg[offset])
    if offset in preds:
        print('pred ', preds[offset])

                



1024  :  0x400   int3      1
cfg  [1025]
1025  :  0x401   int3      1
cfg  [1026]
pred  [1024]
1026  :  0x402   int3      1
cfg  [1027]
pred  [1025]
1027  :  0x403   int3      1
cfg  [1028]
pred  [1026]
1028  :  0x404   int3      1
cfg  [1029]
pred  [1027]
1029  :  0x405   jmp 0xf4c     5
cfg  [3916]
pred  [1028]
1030  :  0x406   or eax, dword ptr [rax]     3
cfg  [1033]
1031  :  0x407   or eax, dword ptr [rax]     2
cfg  [1033]
1032  :  0x408   add byte ptr [rax], al     2
cfg  [1034]
1033  :  0x409   add cl, ch     2
cfg  [1035]
pred  [1030, 1031]
1034  :  0x40a   jmp 0x1dc8     5
cfg  [7624]
pred  [1032, 7488]
1035  :  0x40b   mov ecx, 0xe9000019     5
cfg  [1040]
pred  [1033]
1036  :  0x40c   sbb dword ptr [rax], eax     2
cfg  [1038]
1037  :  0x40d   add byte ptr [rax], al     2
cfg  [1039]
1038  :  0x40e   add cl, ch     2
cfg  [1040]
pred  [1036]
1039  :  0x40f   jmp 0x810     5
cfg  [2064]
pred  [1037, 2112, 2593, 6920]
1040  :  0x410   cld      1
cfg  [1041]
pred  [1035, 1038]

In [9]:
# Per-offset analysis state.
#   RH        : offset -> set of (label, offset) hint tuples, filled by the hint functions.
#   H         : offset -> value, initialised to BOTTOM here.
#   data_prob : offset -> 1.0 where the bytes did not decode, BOTTOM otherwise.
#   res       : empty result map; not populated in this cell.
RH = defaultdict(set)
H = {}
res, data_prob = {}, {}
BOTTOM = None  # Marker for "no value computed yet".

for offset in range(min_offset, max_offset + 1):
    # An offset that does not decode cannot be code, so it starts as certain data.
    data_prob[offset] = 1.0 if disasm[offset] is None else BOTTOM
    H[offset], RH[offset] = BOTTOM, set()




In [11]:
# Capstone groups that make an instruction count as a branch for the hint
# functions: calls, plus conditional and unconditional jumps.
BRANCH_GROUPS = {CS_GRP_JUMP, CS_GRP_CALL}

def _hint_one(offset, prev_list, disasm):

    """ Implements Control Flow Convergence hint. """


    print(offset)
    branches = [prev for prev in prev_list if set(disasm[prev].groups) & BRANCH_GROUPS]
    
    if disasm[offset]:
        print(offset , ' : ', hex(offset),' ',disasm[offset].mnemonic +' ' + disasm[offset].op_str )
    print(branches)
    for branch in branches:
        RH[branch].add(("1rel" if disasm[branch].size == 2 else "1near", offset))
        print(("1rel" if disasm[branch].size == 2 else "1near", offset))



def _hint_two(offset, disasm):

    """ Implements Control Flow Crossing hint. """

    if not set(disasm[offset].groups) & BRANCH_GROUPS:
        return
    
    inst2_offset = offset
    inst2_size = disasm[inst2_offset].size
    inst3_offset =  inst2_offset + inst2_size
    
    branches = [prev for prev in prev_list if set(disasm[prev].groups) & BRANCH_GROUPS]
    if disasm[offset]:
        print(offset , ' : ', hex(offset),' ',disasm[offset].mnemonic +' ' + disasm[offset].op_str  , ' ' , disasm[offset].size)
    
    if branches:
        for size, label in [(2, "2rel"), (5, "2near")]:
            prev_offset = offset - size

            if disasm[prev_offset] is not None:
                if prev_offset in disasm and CS_GRP_JUMP in disasm[prev_offset].groups and disasm[prev_offset].mnemonic == 'jmp':
                    RH[prev_offset].add((label, offset))
                    print('->    ',prev_offset , ' : p: ', label, offset)
                    for branch in branches:
                        RH[branch].add((label, offset))
                        print('->    ' ,branch , ' : b:',label, offset)


# def _hint_three(offset, prev_list, disasm):

#     """ Implements Register Define-Use Relation hint. """
#     if offset not in disasm:
#         return
#     if disasm[offset] is None:
#         return
#     for prev in prev_list:
#         prev_reg_write = set(disasm[prev]['regs_write'])
#         curr_reg_read = set(disasm[offset]['regs_read'])
#         if prev_reg_write & curr_reg_read:
#             RH[prev].add(("3orig", offset))
#             RH[offset].add(("3orig", prev))





# Apply the control-flow-crossing hint at every offset of the decode map.
# The convergence hint (_hint_one over preds[offset]) is currently disabled.
for offset in disasm.keys():
    _hint_two(offset, disasm)

1025
1025  :  0x401   int3 
[]
1026
1026  :  0x402   int3 
[]
1027
1027  :  0x403   int3 
[]
1028
1028  :  0x404   int3 
[]
1029
1029  :  0x405   jmp 0xf4c
[]
1033
1033  :  0x409   add cl, ch
[]
1034
1034  :  0x40a   jmp 0x1dc8
[7488]
('1near', 1034)
1035
1035  :  0x40b   mov ecx, 0xe9000019
[]
1038
1038  :  0x40e   add cl, ch
[]
1039
1039  :  0x40f   jmp 0x810
[2112, 2593, 6920]
('1near', 1039)
('1near', 1039)
('1near', 1039)
1040
1040  :  0x410   cld 
[]
1041
1041  :  0x411   add eax, dword ptr [rax]
[]
1043
1043  :  0x413   add cl, ch
[]
1044
1044  :  0x414   jmp 0x1ed8
[]
1045
1045  :  0x415   mov edi, 0xe900001a
[]
1048
1048  :  0x418   add cl, ch
[]
1049
1049  :  0x419   jmp 0x1568
[]
1050
1050  :  0x41a   adc qword ptr [rax], rax
[]
1053
1053  :  0x41d   add cl, ch
[]
1054
1054  :  0x41e   jmp 0x1af8
[6929]
('1near', 1054)
1055
[]
1059
1059  :  0x423   jmp 0x2684
[]
1060
1060  :  0x424   pop rsp
[]
1061
1061  :  0x425   and al, byte ptr [rax]
[]
1063
1063  :  0x427   add cl, ch


In [None]:
9677  :  0x25cd   add bh, bh
[9674, 9675]