In [1]:
import subprocess
from pathlib import Path

import pandas as pd

from deep_mca.utils import disassemble_hex, wrap_asm

In [2]:
# BHive Skylake throughput data: one row per basic block, holding the hex-encoded
# machine code and `cycles_100` (per the column name, cycles per 100 iterations).
# The file has no header row.
raw_data_path = Path("../data/bhive/benchmark/throughput/skl.csv")
# Force `hex` to text so type inference can never parse an all-digit hex string
# (e.g. "0000") as an integer and silently drop its leading zeros.
raw_df = pd.read_csv(
    raw_data_path,
    header=None,
    names=["hex", "cycles_100"],
    dtype={"hex": str},
)

# Build the cleaned frame as a new object in one chain rather than assigning a
# column onto a boolean-filtered slice. Rows with empty hex or unparseable cycle
# counts are dropped, and the index is reset so positional and label lookups agree.
df = (
    raw_df.assign(
        hex=raw_df["hex"].fillna("").astype(str).str.strip(),
        cycles_100=pd.to_numeric(raw_df["cycles_100"], errors="coerce"),
    )
    .loc[lambda d: d["hex"].ne("") & d["cycles_100"].notna()]
    .reset_index(drop=True)
)

In [3]:
len(df.index), len(df.columns)  # (rows, columns) remaining after cleaning

(314876, 2)

In [4]:
df.iloc[:5]  # first five cleaned rows

Unnamed: 0,hex,cycles_100
0,4183ff0119c083e00885c98945c4b8010000000f4fc139c2,249.0
1,4889de4889c24c89ff,91.0
2,48895d1844886520488945004889e84883c4085b5d415c...,330.0
3,0fb7d5448d40ff8d0cd28d348a4421c689f14c8d0ccd00...,361.0
4,418b4424084d8b3424498d2cc64939ee,100.0


In [5]:
def run_llvm_mca(asm: str, mcpu: str = "skylake", iterations: int = 100, timeout=None) -> float:
    """Run llvm-mca on an assembly snippet and return its "Block RThroughput".

    Args:
        asm: Assembly text fed to llvm-mca on stdin (target triple x86_64).
        mcpu: CPU model passed as ``-mcpu``.
        iterations: Passed as ``-iterations``. This affects llvm-mca's simulated
            statistics; the static Block RThroughput value returned here is
            reported per single iteration of the block.
        timeout: Optional limit in seconds for the llvm-mca process. ``None``
            (the default) waits indefinitely.

    Returns:
        Block RThroughput as a float (cycles per iteration of the block).

    Raises:
        RuntimeError: if llvm-mca exits non-zero or its report lacks a
            "Block RThroughput:" line.
        subprocess.TimeoutExpired: if ``timeout`` is set and exceeded.
    """
    cmd = [
        "llvm-mca",
        "-mtriple=x86_64",
        f"-mcpu={mcpu}",
        f"-iterations={iterations}",
    ]
    proc = subprocess.run(
        cmd,
        input=asm,
        text=True,
        capture_output=True,
        check=False,
        timeout=timeout,
    )
    if proc.returncode != 0:
        # stderr may be empty, so always report the exit code as well.
        raise RuntimeError(f"llvm-mca failed with exit code {proc.returncode}: {proc.stderr.strip()}")

    marker = "Block RThroughput:"
    for line in proc.stdout.splitlines():
        if marker in line:
            # float() tolerates the surrounding whitespace.
            return float(line.split(marker, 1)[1])

    raise RuntimeError("llvm-mca output missing Block RThroughput")

In [6]:
def analyze_row(row, mcpu: str = "skylake", iterations: int = 100, intel: bool = False):
    """Print BHive's measured cycles next to llvm-mca's estimate, then the block's assembly.

    Args:
        row: A row of ``df`` with ``"hex"`` (machine code as a hex string) and
            ``"cycles_100"`` (measured value; per the column name, cycles per 100
            iterations).
        mcpu: CPU model forwarded to ``run_llvm_mca``.
        iterations: Iteration count forwarded to ``run_llvm_mca``. The printed
            estimate is always normalised to 100 iterations regardless of this value.
        intel: Disassemble to Intel syntax instead of AT&T.
    """
    # Unit of BHive's `cycles_100` column: cycles for 100 iterations of the block.
    bhive_iterations = 100

    asm_lines = disassemble_hex(row["hex"], output_intel_syntax=intel)
    # NOTE(review): run_llvm_mca passes no input-syntax option, so intel=True only works
    # if wrap_asm emits an `.intel_syntax` directive; confirm.
    asm = wrap_asm(asm_lines)
    rthroughput = run_llvm_mca(asm, mcpu=mcpu, iterations=iterations)
    # Block RThroughput is cycles per single iteration, so scale it to BHive's
    # 100-iteration unit. Scaling by `iterations` would mis-scale whenever it != 100.
    mca_cycles_100 = rthroughput * bhive_iterations

    print(f"True cycles: {row['cycles_100']}, llvm-mca cycles: {mca_cycles_100}")
    print("Source code:")
    for line in asm_lines:
        print(line)

In [7]:
disassemble_hex(df.at[0, "hex"], output_intel_syntax=False)  # AT&T listing for row 0

['cmpl\t$1, %r15d',
 'sbbl\t%eax, %eax',
 'andl\t$8, %eax',
 'testl\t%ecx, %ecx',
 'movl\t%eax, -60(%rbp)',
 'movl\t$1, %eax',
 'cmovgl\t%ecx, %eax',
 'cmpl\t%eax, %edx']

In [8]:
analyze_row(row=df.iloc[0])

True cycles: 249.0, llvm-mca cycles: 130.0
Source code:
cmpl	$1, %r15d
sbbl	%eax, %eax
andl	$8, %eax
testl	%ecx, %ecx
movl	%eax, -60(%rbp)
movl	$1, %eax
cmovgl	%ecx, %eax
cmpl	%eax, %edx


In [9]:
analyze_row(row=df.iloc[1])

True cycles: 91.0, llvm-mca cycles: 80.0
Source code:
movq	%rbx, %rsi
movq	%rax, %rdx
movq	%r15, %rdi


In [10]:
analyze_row(row=df.iloc[2])

True cycles: 330.0, llvm-mca cycles: 300.0
Source code:
movq	%rbx, 24(%rbp)
movb	%r12b, 32(%rbp)
movq	%rax, (%rbp)
movq	%rbp, %rax
addq	$8, %rsp
popq	%rbx
popq	%rbp
popq	%r12
popq	%r13


In [14]:
import math
import re


class TextAssemblyTokenizer:
    """Tokenize AT&T-syntax x86 instruction strings into ids and log-scaled numbers.

    Each instruction becomes a dict with:
    - a mnemonic id (``mne_id``);
    - the ids of every register it mentions (``regs``);
    - log-scaled immediates followed by memory displacements (``numerical``).

    Mnemonics and registers have separate vocabularies that start with reserved
    special tokens. By default, a string seen for the first time is appended with
    the next free id.
    """

    # Prefixes the disassembler may print as a separate word before the opcode
    # (e.g. "lock\tcmpxchgl ..."). They are folded into the mnemonic so that
    # different prefixed instructions do not all collapse onto the prefix token.
    PREFIXES = frozenset({"lock", "rep", "repe", "repz", "repne", "repnz", "notrack"})

    def __init__(self):
        # Vocabularies: special tokens first, then new entries in order of first use.
        self.vocab = {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3}
        self.reg_vocab = {"<NONE>": 0, "<UNK>": 1}

        # Regex patterns for AT&T syntax
        self.re_reg = re.compile(r"%(\w+)")  # Matches %eax, %r15d
        self.re_imm = re.compile(r"\$([-0-9xA-Fa-f]+)")  # Matches $1, $0xFF, $-8
        # Memory operands disp(base,index,scale). Groups: (displacement, base, index, scale).
        # The base may be empty so index-only forms such as 0x10(,%rcx,8) still match,
        # and hex displacements are matched case-insensitively.
        self.re_mem = re.compile(
            r"(-?0x[0-9a-f]+|-?\d+)?\((%?\w*)(?:,\s*(%?\w+)(?:,\s*(\d+))?)?\)",
            re.IGNORECASE,
        )

    def _get_id(self, key, vocab, add=True):
        """Return the id of ``key`` in ``vocab``.

        If ``key`` is unseen and ``add`` is true, it is appended with id ``len(vocab)``.
        If ``add`` is false, the vocabulary's ``<UNK>`` id is returned and ``vocab``
        is left unchanged.
        """
        if key in vocab:
            return vocab[key]
        if not add:
            return vocab["<UNK>"]
        vocab[key] = len(vocab)
        return vocab[key]

    def normalize_value(self, val_str):
        """Convert a decimal or 0x-prefixed integer literal to a signed log2 scale.

        Returns ``sign(v) * log2(|v| + 1)``. Returns 0.0 for zero or for text that
        ``int(..., 0)`` cannot parse.
        """
        try:
            val = int(val_str, 0)  # Handles '0x10', '16' and a leading '-'
        except (ValueError, TypeError):
            return 0.0

        if val == 0:
            return 0.0  # explicit so zero yields +0.0 rather than -0.0
        sign = 1 if val > 0 else -1
        return sign * math.log2(abs(val) + 1)

    def tokenize_block(self, instr_list, update_vocab=True):
        """Tokenize a list of instruction strings.

        Args:
            instr_list: Instruction strings, e.g. ['movl %eax, -60(%rbp)', ...].
                Blank lines are skipped.
            update_vocab: If True (default), unseen mnemonics/registers are added to
                the vocabularies. If False, they map to ``<UNK>`` and the vocabularies
                are not modified (e.g. once a model's embedding tables are fixed).

        Returns:
            One dict per non-blank line for the Mamba dataset, with keys:
            - ``mne_id`` (int);
            - ``regs`` (register ids in order of appearance);
            - ``numerical`` (log-scaled immediates, then log-scaled memory
              displacements).
        """
        tokenized_block = []

        for line in instr_list:
            # 1. Split into words; skip blank lines.
            parts = line.split()
            if not parts:
                continue

            # Fold any leading prefixes into the mnemonic, never consuming the last word.
            n_words = 1
            while n_words < len(parts) and parts[n_words - 1].lower() in self.PREFIXES:
                n_words += 1
            mnemonic = " ".join(parts[:n_words])
            # Operands are concatenated without separators; the regexes below do not need spaces.
            operands_str = "".join(parts[n_words:])

            instr_data = {
                "mne_id": self._get_id(mnemonic, self.vocab, update_vocab),
                "regs": [],
                "numerical": [],
            }

            # 2. Registers: every %reg in the operands (source, destination, base, index).
            for reg in self.re_reg.findall(operands_str):
                instr_data["regs"].append(self._get_id(reg, self.reg_vocab, update_vocab))

            # 3. Immediates such as $1.
            for imm in self.re_imm.findall(operands_str):
                instr_data["numerical"].append(self.normalize_value(imm))

            # 4. Memory displacements such as -60 in -60(%rbp); operands without one add nothing.
            for disp_str, _base, _index, _scale in self.re_mem.findall(operands_str):
                if disp_str:
                    instr_data["numerical"].append(self.normalize_value(disp_str))

            tokenized_block.append(instr_data)

        return tokenized_block


# --- Usage with your data ---
# The first five instructions of the row-0 block (matching the disassembly shown earlier).
input_data = [
    "cmpl\t$1, %r15d",
    "sbbl\t%eax, %eax",
    "andl\t$8, %eax",
    "testl\t%ecx, %ecx",
    "movl\t%eax, -60(%rbp)",
]

tokenizer = TextAssemblyTokenizer()
output = tokenizer.tokenize_block(input_data)

# Print verification: one section per source instruction.
separator = "-" * 50
for raw_line, tokens in zip(input_data, output):
    print(f"Original: {raw_line}")
    print(f"Tokens:   MnemonicID: {tokens['mne_id']} | RegIDs: {tokens['regs']} | nums: {tokens['numerical']}")
    print(separator)

Original: cmpl	$1, %r15d
Tokens:   MnemonicID: 4 | RegIDs: [2] | nums: [1.0]
--------------------------------------------------
Original: sbbl	%eax, %eax
Tokens:   MnemonicID: 5 | RegIDs: [3, 3] | nums: []
--------------------------------------------------
Original: andl	$8, %eax
Tokens:   MnemonicID: 6 | RegIDs: [3] | nums: [3.169925001442312]
--------------------------------------------------
Original: testl	%ecx, %ecx
Tokens:   MnemonicID: 7 | RegIDs: [4, 4] | nums: []
--------------------------------------------------
Original: movl	%eax, -60(%rbp)
Tokens:   MnemonicID: 8 | RegIDs: [3, 5] | nums: [-5.930737337562887]
--------------------------------------------------


In [16]:
# Show each instruction record, then the flattened stream of ids and numbers.
# NOTE(review): mnemonic ids, register ids and numerical values share one flat list,
# so the per-instruction boundaries cannot be recovered from `flat` alone.
for record in output:
    print(record)
flat = [
    token
    for instr in output
    for token in (instr["mne_id"], *instr["regs"], *instr["numerical"])
]
print(flat)

{'mne_id': 4, 'regs': [2], 'numerical': [1.0]}
{'mne_id': 5, 'regs': [3, 3], 'numerical': []}
{'mne_id': 6, 'regs': [3], 'numerical': [3.169925001442312]}
{'mne_id': 7, 'regs': [4, 4], 'numerical': []}
{'mne_id': 8, 'regs': [3, 5], 'numerical': [-5.930737337562887]}
[4, 2, 1.0, 5, 3, 3, 6, 3, 3.169925001442312, 7, 4, 4, 8, 3, 5, -5.930737337562887]


In [None]:
# NOTE(review): repeats the verification printout from the tokenizer cell; consider deleting this cell.
rule = "-" * 50
for tokens, raw_line in zip(output, input_data):
    print("Original: " + raw_line)
    print(
        "Tokens:   MnemonicID: {} | RegIDs: {} | nums: {}".format(
            tokens["mne_id"], tokens["regs"], tokens["numerical"]
        )
    )
    print(rule)