Extract x86 basic blocks from arbitrary binaries using objdump so we can collect a bunch of data for pretraining.

In [9]:
import re
import subprocess
from dataclasses import dataclass
from pathlib import Path

In [24]:
@dataclass
class BasicBlock:
    hex_bytes: str
    instructions: list[str]
    address: int

In [25]:
# Instructions that end a basic block (any change in control flow)
BLOCK_TERMINATORS = frozenset(
    [
        # Jumps
        "jmp",
        "jmpq",
        "je",
        "jne",
        "jz",
        "jnz",
        "jg",
        "jge",
        "jl",
        "jle",
        "ja",
        "jae",
        "jb",
        "jbe",
        "js",
        "jns",
        "jo",
        "jno",
        "jp",
        "jnp",
        "jpe",
        "jpo",
        "jcxz",
        "jecxz",
        "jrcxz",
        # Loops
        "loop",
        "loope",
        "loopne",
        "loopz",
        "loopnz",
        # Calls/returns
        "call",
        "callq",
        "ret",
        "retq",
        "retf",
        "iret",
        "iretq",
        # System
        "syscall",
        "sysenter",
        "sysexit",
        "int",
        "int3",
        "into",
        # Other
        "hlt",
        "ud2",
    ]
)

In [26]:
# Cursed regex to parse objdump output: "  401000:\t48 89 e5\tmov    %rsp,%rbp"
# Address, then hex bytes, then mnemonic, then maybe operands
INSTRUCTION_RE = re.compile(
    r"^\s*([0-9a-f]+):\s+"
    r"((?:[0-9a-f]{2}\s)+)\s+"
    r"(\S+)"
    r"(?:\s+(.*))?$",
    re.IGNORECASE,
)

LABEL_RE = re.compile(r"^[0-9a-f]+\s+<(.+)>:$", re.IGNORECASE)

In [27]:
def extract_basic_blocks(binary_path: str | Path) -> list[BasicBlock]:
    cmd = ["objdump", "-d", "-M", "att", str(binary_path)]
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)

    blocks = []
    current_hex, current_instrs, current_addr = [], [], None

    def finish_block():
        nonlocal current_hex, current_instrs, current_addr
        if current_instrs:
            blocks.append(
                BasicBlock(
                    hex_bytes="".join(current_hex),
                    instructions=current_instrs,
                    address=current_addr,
                )
            )
        current_hex, current_instrs, current_addr = [], [], None

    for line in result.stdout.splitlines():
        if LABEL_RE.match(line):
            finish_block()
            continue

        match = INSTRUCTION_RE.match(line)
        if not match:
            continue

        addr_str, hex_bytes, mnemonic, operands = match.groups()
        addr = int(addr_str, 16)
        hex_clean = hex_bytes.replace(" ", "").strip()
        instr = f"{mnemonic}\t{operands.strip()}" if operands else mnemonic

        if current_addr is None:
            current_addr = addr

        current_hex.append(hex_clean)
        current_instrs.append(instr)

        if mnemonic.lower() in BLOCK_TERMINATORS:
            finish_block()

    finish_block()
    return blocks

In [28]:
blocks = extract_basic_blocks("/bin/ls")
print(f"Extracted {len(blocks)} basic blocks")

Extracted 4823 basic blocks


In [32]:
for block in blocks[:3]:
    print()
    print(block.hex_bytes)
    for instr in block.instructions:
        print(instr)


f30f1efa4883ec08488b05b9ef01004885c07402
endbr64
sub	$0x8,%rsp
mov	0x1efb9(%rip),%rax        # 22fc8 <__gmon_start__@Base>
test	%rax,%rax
je	4016 <free@plt-0x69a>

ffd0
call	*%rax

4883c408c3
add	$0x8,%rsp
ret


In [31]:
# Verify we can disassemble the hex correctly
from deep_mca.utils import disassemble_hex

test_block = blocks[100]
print("Original:")
for instr in test_block.instructions:
    print(f"\t{instr}")


print("Re-disassembled via llvm-mc:")
for instr in disassemble_hex(test_block.hex_bytes):
    print(f"\t{instr}")

Original:
	xchg	%ax,%ax
	endbr64
	push	$0x60
	jmp	4020 <free@plt-0x690>
Re-disassembled via llvm-mc:
	nop
	endbr64
	pushq	$96
	jmp	-1566
