In [None]:
from datasets import load_from_disk
from loguru import logger
from pathlib import Path
import random
import pickle
import zstd

We cleared the output of the cell above because it prints a warning that contains identifying information.

In [2]:
def decompress_data(b_str):
    """Inverse of `compress_data`: zstd-decompress `b_str`, then unpickle it.

    Only call this on trusted blobs -- `pickle.loads` can execute arbitrary
    code embedded in the payload.
    """
    pickled = zstd.decompress(b_str)
    return pickle.loads(pickled)
def compress_data(obj):
    """Serialize `obj` with pickle and zstd-compress the resulting bytes."""
    pickled = pickle.dumps(obj)
    return zstd.compress(pickled)

In [3]:
# Raw dataset: each row holds compressed disassembly blobs plus the function's source text.
raw_ds = load_from_disk(dataset_path="data/raw")

# Fix Jump Target

`fix_jump` rebases the operands of jump instructions such as `jmp`. Each jump operand is replaced by a token `[INSTRk]`, where `k` is the index of the target instruction within the function's instruction sequence.

This approach was first proposed in *jTrans: Jump-Aware Transformer for Binary Code Similarity Detection*.

In [4]:
def fix_jump(asm):
    """Rebase jump operands to instruction indices and render each row as text.

    Args:
        asm: dict with
            "code": list of instruction rows, each a list of string tokens;
            "jump_dict": {row index: {"target": target row index,
                                      "token_idx": position of the operand token}}.

    Returns:
        list[str]: one space-joined instruction per row, where every jump
        operand listed in "jump_dict" is replaced by "[INSTR<target row index>]".

    The input dict is not modified: tokens are patched on per-row copies, so
    the same `asm` can safely be passed to other helpers afterwards.
    """
    code = [list(row) for row in asm["code"]]
    for row_idx, jump in asm["jump_dict"].items():
        code[row_idx][jump["token_idx"]] = f"[INSTR{jump['target']}]"
    return [" ".join(row) for row in code]

# Establish a fine-grained mapping between source code and binaries

`fix_row` is one of the most important functions in `BinQuery`.

This function first divides both the binary code and the source code into `ranges`. Note that a `range` here is not the same as the `snippet` described in the paper.

When `gcc` compiles the source code, not every source line corresponds to an assembly instruction, and not every assembly instruction address can be mapped back to a source line.

Therefore, we use the `binary addresses` and `source lines` that do map to each other as `anchors`.

Based on these `anchors`, we divide all of the source code and binary code into `ranges`. We then build a `range`-level mapping, which is later used to build the `snippet`-level mapping.

**Please note** that a `range` is merely an implementation detail and **is not** the mechanism described in the paper.

In [5]:
def fix_row(row):
    """Build the range-level alignment between a function's source and its asm.

    Anchors are addresses that appear both in the stripped asm's
    "anchor_map" and in the address->source-line table, restricted to
    source lines inside the table's "line" span. Both the source lines and
    the (stripped) asm instructions are cut into half-open ranges starting
    at each anchor. Every source range that holds an anchor is mapped to the
    asm range(s) starting at that line's asm anchor(s).

    Args:
        row: one dataset example with compressed "asm", "no_strip_asm" and
            "addr_to_source_code_file_line" blobs plus the "source_code" text.

    Returns:
        dict with the keys "asm", "no_strip_asm", "src", "asm_range_list",
        "src_range_list" and "src_asm_range_map", each produced by
        `compress_data`. When no usable anchor exists, every value is b""
        so the row can be filtered out afterwards.
    """
    empty_result = {
        "asm": b"",
        "no_strip_asm": b"",
        "src": b"",
        "asm_range_list": b"",
        "src_range_list": b"",
        "src_asm_range_map": b"",
    }

    asm = decompress_data(row["asm"])
    no_strip_asm = decompress_data(row["no_strip_asm"])
    src = row["source_code"].split("\n")

    source_file_line = decompress_data(row["addr_to_source_code_file_line"])
    # Inclusive [start, end] span of file line numbers covered by `src`.
    src_line_start, src_line_end = source_file_line["line"]
    if len(asm["anchor_map"]) == 0 or len(no_strip_asm["anchor_map"]) == 0:
        logger.error(f"splitter is empty: {row}")
        return empty_result

    def convert_to_range(anchor_idx_list, total):
        """Cut [0, total) into consecutive half-open ranges at each anchor.

        `anchor_idx_list` must be sorted, unique and within [0, total);
        0 and `total` are added as outer boundaries when missing.
        An empty anchor list yields the single range (0, total).
        """
        bounds = list(anchor_idx_list)
        if not bounds or bounds[0] != 0:
            bounds.insert(0, 0)
        if bounds[-1] != total:
            bounds.append(total)
        return list(zip(bounds, bounds[1:]))

    def align(asm_dict):
        """Return (asm_code, src_range_list, asm_range_list, src_asm_range_map),
        or None when no source line could be anchored to an instruction."""
        asm_code = fix_jump(asm_dict)
        anchor_map = asm_dict["anchor_map"]

        # Index into `src` -> asm line indices of the anchor addresses
        # attributed to that source line.
        line_map = {}
        for addr, src_line_info in source_file_line["mapping"].items():
            if addr not in anchor_map:
                continue
            _, src_lino = src_line_info
            if src_lino < src_line_start or src_lino > src_line_end:
                continue
            src_lino -= src_line_start
            asm_lino = anchor_map[addr]
            # An anchor at or past the end of `src` / `asm_code` would become
            # a range end with no matching range start (KeyError below) or
            # an inverted range, so it is dropped.
            if src_lino >= len(src) or not 0 <= asm_lino < len(asm_code):
                continue
            line_map.setdefault(src_lino, []).append(asm_lino)

        if not line_map:
            return None

        # Per source line (in ascending order), keep only the first asm line
        # of each run of consecutive asm lines so the run becomes one asm
        # range; exact duplicates are dropped while preserving order.
        collapsed = {}
        for src_lino in sorted(line_map):
            asm_linos = line_map[src_lino]
            present = set(asm_linos)
            collapsed[src_lino] = [
                asm_lino
                for asm_lino in dict.fromkeys(asm_linos)
                if asm_lino - 1 not in present
            ]
        line_map = collapsed

        src_anchor_list = sorted(line_map)
        asm_anchor_list = sorted(
            {asm_lino for asm_linos in line_map.values() for asm_lino in asm_linos}
        )
        asm_range_list = convert_to_range(asm_anchor_list, len(asm_code))
        src_range_list = convert_to_range(src_anchor_list, len(src))
        asm_lino_to_range_idx = {r[0]: idx for idx, r in enumerate(asm_range_list)}
        src_lino_to_range_idx = {r[0]: idx for idx, r in enumerate(src_range_list)}
        src_asm_range_map = {
            src_lino_to_range_idx[src_lino]: [
                asm_lino_to_range_idx[asm_lino] for asm_lino in asm_linos
            ]
            for src_lino, asm_linos in line_map.items()
        }

        # Source lines before the first anchor fall into range 0; pair that
        # range with asm range 0 when no anchor mapped it already.
        src_asm_range_map.setdefault(0, [0])

        return asm_code, src_range_list, asm_range_list, src_asm_range_map

    aligned = align(asm)
    if aligned is None:
        logger.error(f"no anchor left after alignment: {row}")
        return empty_result

    asm_code, src_range_list, asm_range_list, src_asm_range_map = aligned
    return {
        "asm": compress_data(asm_code),
        "no_strip_asm": compress_data(fix_jump(no_strip_asm)),
        "src": compress_data(src),
        "asm_range_list": compress_data(asm_range_list),
        "src_range_list": compress_data(src_range_list),
        "src_asm_range_map": compress_data(src_asm_range_map),
    }

In [6]:
# Align every row, then drop the rows for which fix_row found no usable
# anchors (it returns b"" blobs for them) before saving the result.
ds = raw_ds.map(
    fix_row,
    num_proc=64,
    remove_columns=["addr_to_source_code_file_line", "source_code"],
).filter(lambda example: example["asm"] != b"", num_proc=64)
ds.save_to_disk("data/line_matched")

Map (num_proc=64): 100%|██████████| 100/100 [00:00<00:00, 219.90 examples/s]
Filter (num_proc=64): 100%|██████████| 100/100 [00:00<00:00, 265.75 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 3318.83 examples/s]


# Display the established mapping

In [7]:
# Fixed seed so Restart & Run All shows the same example every time;
# change SAMPLE_SEED to inspect a different row.
SAMPLE_SEED = 0
row = random.Random(SAMPLE_SEED).choice(ds)

In [8]:
# Unpack the sampled example: rebased instructions, source lines and the
# range-level mapping between them.
asm, src, asm_range_list, src_range_list, src_asm_range_map = (
    decompress_data(row[field])
    for field in (
        "asm",
        "src",
        "asm_range_list",
        "src_range_list",
        "src_asm_range_map",
    )
)

In [9]:
# Dump the rebased instructions with their indices so the [INSTRk] jump
# tokens and the asm ranges printed below can be cross-referenced.
print(f"{'=' * 10}   asm start   {'=' * 10}")
for instr_idx, instr_text in enumerate(asm):
    print(f"[{instr_idx:>4}] {instr_text}")
print(f"{'=' * 10}   asm   end   {'=' * 10}")

[   0] endbr64
[   1] push rbp
[   2] mov rbp rsp
[   3] sub rsp 0B0h
[   4] mov [rbp+var_A8] rdi
[   5] mov rax fs:28h
[   6] mov [rbp+var_8] rax
[   7] xor eax eax
[   8] lea rax [rbp+var_70]
[   9] mov rdi rax
[  10] call sub_8F8EA
[  11] mov rax [rbp+var_A8]
[  12] mov eax [rax+8]
[  13] lea rcx [rbp+buf]
[  14] mov edx 18h
[  15] mov rsi rcx
[  16] mov edi eax
[  17] call _read
[  18] mov [rbp+var_98] rax
[  19] cmp [rbp+var_98] 0
[  20] setnle al
[  21] test al al
[  22] jz [INSTR130]
[  23] mov rax [rbp+var_98]
[  24] cmp rax 17h
[  25] ja [INSTR36]
[  26] mov rax [rbp+var_A8]
[  27] mov rax [rax+10h]
[  28] mov rdx [rbp+var_98]
[  29] mov rcx rdx
[  30] mov rdx rax
[  31] lea rsi aMouseReadEvent
[  32] mov edi 4
[  33] mov eax 0
[  34] call sub_21B3E
[  35] jmp [INSTR129]
[  36] movzx eax [rbp+var_80]
[  37] movzx eax ax
[  38] cmp eax 1
[  39] jz [INSTR43]
[  40] cmp eax 2
[  41] jz [INSTR74]
[  42] jmp [INSTR122]
[  43] movzx eax [rbp+var_7E]
[  44] movzx eax ax
[  45] and ea

In [10]:
# Dump the source lines with their indices so the src ranges printed below
# can be cross-referenced.
print(f"{'=' * 10}   src start   {'=' * 10}")
for line_idx, line_text in enumerate(src):
    print(f"[{line_idx:>4}] {line_text}")
print(f"{'=' * 10}   src   end   {'=' * 10}")

[   0]     virtual void callback()
[   1]     {        
[   2]         //ps2_mouse_event event;
[   3]         input_event event;
[   4]         HidMsg msg;
[   5]         ssize_t len;
[   6]                 
[   7]         while( ( len = read( fd, &event, sizeof( event ) ) ) > 0 )
[   8]         {
[   9]             if( len < sizeof( event ) )
[  10]             {
[  12]                 continue;
[  13]             }
[  14]             
[  15]             switch( event.type )
[  16]             {
[  17]                 case EV_KEY:
[  18]                     if( event.code & BTN_MOUSE )
[  19]                     {
[  20]                         msg.clear();
[  21]                         msg.device_type = CK_HID_DEV_MOUSE;
[  22]                         msg.device_num = num;
[  23]                         msg.eid = event.code - BTN_MOUSE;
[  24]                         msg.type = event.value ? CK_HID_BUTTON_DOWN : CK_HID_BUTTON_UP;
[  25]                         msg.idata[0] = event.

In [11]:
def range_to_str(r):
    """Format a half-open (start, end) range: the bare start index when it
    covers exactly one line, otherwise the interval notation "[start,end)"."""
    if r[1] - r[0] != 1:
        return f"[{r[0]},{r[1]})"
    return f"{r[0]}"


banner = "=" * 10 + "   binary-source-mapping   " + "=" * 10
print(banner)
for position, (src_range_idx, asm_range_indices) in enumerate(src_asm_range_map.items()):
    # Separator goes between entries only, never before the first one.
    if position:
        print("-" * len(banner))
    src_range = src_range_list[src_range_idx]
    asm_text = ", ".join(range_to_str(asm_range_list[i]) for i in asm_range_indices)
    print(f"src: {range_to_str(src_range)}")
    print(f"asm: {asm_text}")
print(banner)

src: [0,4)
asm: [4,8)
-----------------------------------------------
src: [4,7)
asm: [8,11)
-----------------------------------------------
src: [7,9)
asm: [11,23), [129,137)
-----------------------------------------------
src: [9,11)
asm: [23,26)
-----------------------------------------------
src: 11
asm: [26,35)
-----------------------------------------------
src: [12,15)
asm: [35,37)
-----------------------------------------------
src: [15,18)
asm: [37,43)
-----------------------------------------------
src: [18,20)
asm: [43,48)
-----------------------------------------------
src: 20
asm: [48,51)
-----------------------------------------------
src: 21
asm: 51
-----------------------------------------------
src: 22
asm: [52,55)
-----------------------------------------------
src: 23
asm: [55,60)
-----------------------------------------------
src: 24
asm: [60,67)
-----------------------------------------------
src: 25
asm: [67,70)
-----------------------------------------------
src