# Assemblage adoption code for jTrans


Put the function ids you want to train on into a text file (one id per line) and set `TXTINPUT` to that file's path, `dbfile` to the Assemblage SQLite database path, and `dataset_path` to the folder containing the binaries.

In [None]:
import json
import sys
import os
import sqlite3
import glob
from tqdm import tqdm
import shutil
import hashlib
import pandas as pd

def getmd5(s):
    """Return the hexadecimal MD5 digest of the string ``s``.

    The string is encoded with ``str.encode()`` (UTF-8 by default) before
    hashing.
    """
    hasher = hashlib.md5()
    hasher.update(s.encode())
    return hasher.hexdigest()

# Paths: the Assemblage SQLite database, the root folder of the binaries,
# and the folder name used for the flattened copy of the dataset.
dbfile = './sept25.sqlite'
dataset_path = './dataset_sept25'
flatten_dir = "./dataset"

# Text file with one function id per line. The CSV gets the same base name;
# splitext guarantees the output path differs from the input even when the
# input does not end in ".txt".
TXTINPUT = "train_function_id.txt"
CSVOUT = os.path.splitext(TXTINPUT)[0] + '.csv'

# Skip blank lines (e.g. a trailing newline) instead of crashing in int().
with open(TXTINPUT) as f:
    fids = [int(line) for line in f if line.strip()]

connection = sqlite3.connect(dbfile)
try:
    cursor = connection.cursor()

    # Cols function_id, function_name, rvas.start, binary_id,
    # binary_compiler (toolset_version), binary_optimization, binary_github_url
    infos = cursor.execute('''SELECT f.id, f.name, r.start, b.id AS binary_id, b.toolset_version, b.optimization, b.github_url
    FROM functions f
    JOIN rvas r ON r.function_id = f.id
    JOIN binaries b ON b.id = f.binary_id;''')
    csvrows = [list(info) for info in infos]
finally:
    # Release the database handle even if the query fails.
    connection.close()

df = pd.DataFrame(csvrows, columns=['function_id', 'function_name', 'rva_start', 'binary_id', 'binary_compiler', 'binary_optimization', 'binary_github_url'])
# Keep only the rows of the requested functions and write them out.
df = df[df['function_id'].isin(fids)]
df.to_csv(CSVOUT, index=False)



Sort out the binary files

In [None]:
# This block flattens the dataset folder: each selected binary is copied, together with its PDB files, into its own folder named after the binary id

import sys
import os
import sqlite3
import glob
from tqdm import tqdm
import shutil
import hashlib

def getmd5(s):
    """Hash the text ``s`` with MD5 and return the digest as a hex string."""
    encoded = s.encode()
    digest = hashlib.md5(encoded)
    return digest.hexdigest()

dbfile = 'sept25.sqlite'
dataset_path = 'dataset_sept25'
flatten_dir = "dataset"

# Start from an empty output folder. Removing it through shutil/os avoids
# interpolating the path into a shell command line.
if os.path.isdir(flatten_dir):
    shutil.rmtree(flatten_dir)
elif os.path.exists(flatten_dir):
    os.remove(flatten_dir)
os.makedirs(flatten_dir)

connection = sqlite3.connect(dbfile)
cursor = connection.cursor()

import pandas as pd

df = pd.read_csv(CSVOUT)
# Plain-int set: O(1) membership test per database row.
binaryids = {int(b) for b in df['binary_id'].unique()}

try:
    infos = cursor.execute('SELECT id, path, file_name, optimization, github_url, toolset_version FROM binaries;')
    for binid, path, file_name, opt, github_url, toolset_version in tqdm(infos):
        # Only binaries that contain at least one selected function are copied.
        if int(binid) not in binaryids:
            continue
        # Paths in the database use Windows separators.
        full_path = os.path.join(dataset_path, path.replace("\\", "/"))
        if not os.path.isfile(full_path):
            print("Missing", full_path)
            continue
        bin_dir = os.path.join(flatten_dir, str(binid))
        os.makedirs(bin_dir, exist_ok=True)
        # Original datautils/dataset/libcap-git-setcap-O2-8dc43f20ea80b7703f6973a1ea86e8b8
        # The binary id is prefixed so later cells can recover it from the file name.
        shutil.copy(full_path, os.path.join(bin_dir, f"{binid}_{file_name}-{toolset_version}-{opt}-{getmd5(github_url)}"))
        # A separate cursor keeps the outer SELECT's iteration intact.
        newcursor = connection.cursor()
        pdbs = newcursor.execute('SELECT pdb_path FROM pdbs where binary_id = ?', (binid,))
        for (pdb_path,) in pdbs:
            pdb_rel = pdb_path.replace("\\", "/")
            pdb_full = os.path.join(dataset_path, pdb_rel)
            if not os.path.isfile(pdb_full):
                print("Missing", pdb_full)
                continue
            shutil.copy(pdb_full, os.path.join(bin_dir, os.path.basename(pdb_rel)))
finally:
    # Release the database handle even if a copy fails.
    connection.close()


The jTrans code from the original author

In [None]:
# jTrans code, not modified much
import os
import subprocess
import multiprocessing
import time
from util.pairdata import pairdata
from subprocess import STDOUT, check_output
import glob
import shutil

# Executable used to launch IDA and the script it runs on every binary.
ida_path="idat64"
script_path = "./process.py"


# Recreate empty working folders without going through a shell, so a failure
# raises instead of being silently ignored:
#   extract - input folder later handed to pairdata()
#   log     - per-binary IDA log files
#   idb     - per-binary IDA databases
for _workdir in ("extract", "log", "idb"):
    if os.path.isdir(_workdir):
        shutil.rmtree(_workdir)
    os.makedirs(_workdir)

def getTarget(path, prefixfilter=None):
    """Recursively list the regular files found below ``path``.

    Args:
        path: Root folder to search (expanded as ``{path}/**/*``).
        prefixfilter: Optional iterable of string prefixes. When given, a file
            is kept only if its path, exactly as returned by ``glob`` (i.e.
            including ``path`` itself), starts with at least one prefix.

    Returns:
        A list of matching file paths; directories are excluded.
    """
    selected = []
    for candidate in glob.glob(f'{path}/**/*', recursive=True):
        if not os.path.isfile(candidate):
            continue
        if prefixfilter is not None:
            hits = [candidate.startswith(prefix) for prefix in prefixfilter]
            if not any(hits):
                continue
        selected.append(candidate)
    return selected

def cmd_warp(cmd, timeout):
    """Run ``cmd`` to completion and discard its output.

    Args:
        cmd: Argument list of the command to execute.
        timeout: Maximum run time in seconds.

    Returns:
        None.

    Raises:
        subprocess.CalledProcessError: The command exited with a non-zero status.
        subprocess.TimeoutExpired: The command did not finish within ``timeout``.
    """
    # stdout is captured (stderr merged into it) and then thrown away.
    subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        timeout=timeout,
        check=True,
    )
    return None

dataset_dir = "dataset"
start = time.time()
target_list = getTarget(dataset_dir)

# Run IDA on every binary in parallel. Each job is an external process, so the
# pool workers mostly wait on it.
pool = multiprocessing.Pool(processes=128)
for target in target_list:
    # Libraries and debug symbols live next to the binaries but are not analysed.
    if target.lower().endswith(("lib", "pdb")):
        continue
    filename = os.path.basename(target)
    # IDA switches: -L log file, -c start from a fresh database, -A autonomous
    # (no dialogs), -S script to run, -o output database.
    cmd = [ida_path, f'-Llog/{filename}.log', '-c', '-A', f'-S{script_path}', f'-oidb/{filename}.idb', f'{target}']
    # Without an error callback, a failed or timed-out job would disappear
    # silently because the async result is never fetched.
    pool.apply_async(
        cmd_warp,
        args=(cmd, 600, ),
        error_callback=lambda exc, target=target: print("IDA failed for", target, "-", repr(exc)),
    )


pool.close()
pool.join()

pairdata("extract")

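This code block creates the folder `addr_ref` (one JSON file per binary id, mapping function name to RVA start) as a lookup cache for performance reasons.
Querying SQLite for every function is very slow. Note: the cell below only writes `addr_ref`; it does not write a `binid2hash2id.json` file.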

In [None]:
import glob
import hashlib
import sys
import os
import sqlite3
import glob
from tqdm import tqdm
import shutil
import hashlib
import pickle
from hashlib import sha256
from tqdm import tqdm
import json

def getmd5(s):
    """Return the MD5 hex digest of ``s`` after encoding it to bytes."""
    raw = s.encode()
    return hashlib.md5(raw).hexdigest()

def sha256sum(b):
    """Return the SHA-256 digest of the bytes ``b`` as a lowercase hex string."""
    return sha256(b).hexdigest()

dbfile = 'sept25.sqlite'

connection = sqlite3.connect(dbfile)
cursor = connection.cursor()


# binary id -> {function name: rva start}. When a name occurs more than once
# within a binary, the last row read wins.
db = {}

try:
    rows = cursor.execute('SELECT f.name, f.binary_id, r.start FROM functions f JOIN rvas r ON f.id==r.function_id;')
    for function_name, binid, rva in tqdm(rows):
        db.setdefault(binid, {})[function_name] = rva
finally:
    # Release the database handle even if the query fails.
    connection.close()

# Rebuild the cache folder from scratch, without shelling out to rm/mkdir.
if os.path.isdir("addr_ref"):
    shutil.rmtree("addr_ref")
os.makedirs("addr_ref", exist_ok=True)
for binid, name2rva in tqdm(db.items()):
    with open(f'./addr_ref/{binid}.json', 'w') as f:
        json.dump(name2rva, f)

Address conversion code: it reads the pickles in `extract`, maps addresses to real function names, and writes the results to `extracted_modify`.  
It converts the jTrans info {sub_xxxx:[addr, insts, bytes, cfg, bai_feture]} to {real_name:[addr, insts, bytes, cfg, bai_feture]}

In [None]:
import glob
import hashlib
import sys
import os
import sqlite3
import glob
from tqdm import tqdm
import shutil
import hashlib
import pickle
from hashlib import sha256
import pefile
import multiprocessing

def getmd5(s):
    """Compute the MD5 checksum of the string ``s`` and return it in hex form."""
    md5_state = hashlib.md5()
    md5_state.update(s.encode())
    return md5_state.hexdigest()

import json


def sha256sum(b):
    """Hash the bytes ``b`` with SHA-256 and return the hex-encoded digest."""
    hasher = sha256(b)
    return hasher.hexdigest()

from collections import defaultdict

def run(f):
    """Rename the functions of one extracted pickle to their database names.

    The binary id is the part of the file name before the first ``_``. It
    selects ``addr_ref/<binid>.json`` (function name -> RVA) and the binary
    under ``dataset/<binid>/``. The first element of every pickle entry is
    treated as the function's absolute address. The binary's ImageBase is
    subtracted from it, and the entry is re-keyed to the known function with
    the greatest RVA that is <= the result (an exact match when one exists).
    Entries located below every known RVA keep their original key.

    Args:
        f: Path to a ``*_extract.pkl`` file.

    Side effects:
        Writes the re-keyed dict to ``./extracted_modify/<basename of f>``,
        creating the folder if needed.
    """
    import bisect

    base = os.path.basename(f)
    binid = base.split("_")[0]
    with open(f'addr_ref/{binid}.json', 'r') as fh:
        name2addr = json.load(fh)
    # Invert to RVA -> name; when several names share an RVA the last one wins.
    addr2name = {addr: name for name, addr in name2addr.items()}
    known_addrs = sorted(addr2name)
    fpath = os.path.join("dataset", binid, base.split('_extract.pkl')[0])

    with open(f, "rb") as fh:
        saved_index = pickle.load(fh)
    # Snapshot the keys: the dict is re-keyed while we walk over them.
    keys_stored = list(saved_index.keys())

    # ImageBase is a property of the binary, so parse the PE once instead of
    # once per function, and release the file afterwards.
    image_base = 0
    if keys_stored:
        peobj = pefile.PE(fpath, fast_load=True)
        try:
            image_base = peobj.OPTIONAL_HEADER.ImageBase
        finally:
            peobj.close()

    for x in keys_stored:
        relative_addr = saved_index[x][0] - image_base
        # Index of the greatest known RVA <= relative_addr, or -1 if none.
        pos = bisect.bisect_right(known_addrs, relative_addr) - 1
        if pos < 0:
            # No known function starts at or before this address; keep the
            # original key rather than borrowing an unrelated name.
            continue
        saved_index[addr2name[known_addrs[pos]]] = saved_index.pop(x)

    os.makedirs("./extracted_modify", exist_ok=True)
    with open(os.path.join("./extracted_modify", base), "wb") as fh:
        pickle.dump(saved_index, fh)
        

# run() writes its results into this folder; make sure it exists before any
# worker starts.
os.makedirs("extracted_modify", exist_ok=True)

pool = multiprocessing.Pool(processes=128)

for f in tqdm(glob.glob("extract/**/*", recursive=True)):
    if f.endswith("extract.pkl"):
        # Without an error callback, an exception raised inside a worker would
        # be lost because the async result is never fetched.
        pool.apply_async(
            run,
            args=(f,),
            error_callback=lambda exc, f=f: print("Failed to convert", f, "-", repr(exc)),
        )

pool.close()
pool.join()


Write the selected functions to the `extract_selected` folder, which is used for evaluation

In [None]:
import glob
import hashlib
import sys
import os
import sqlite3
import glob
from tqdm import tqdm
import shutil
import hashlib
import pickle
from hashlib import sha256
import pefile
import multiprocessing

def getmd5(s):
    """Return the hex-encoded MD5 digest of the string ``s``."""
    payload = s.encode()
    checksum = hashlib.md5(payload).hexdigest()
    return checksum

import json


def sha256sum(b):
    """Return the hexadecimal SHA-256 digest of the bytes ``b``."""
    state = sha256()
    state.update(b)
    return state.hexdigest()

from collections import defaultdict

import pandas as pd
        
df = pd.read_csv(CSVOUT)
print("Read csv file")
# The output folder is not created anywhere else, so create it here.
os.makedirs("./extract_selected", exist_ok=True)
# For every renamed pickle, keep only the functions listed in the CSV for
# that binary.
for f in tqdm(glob.glob("extracted_modify/**/*", recursive=True)):
    if not f.endswith("extract.pkl"):
        continue
    # The binary id is the part of the file name before the first "_".
    binid = os.path.basename(f).split("_")[0]
    # Set of wanted names: O(1) lookup per function instead of an array scan.
    wanted = set(df.loc[df['binary_id'] == int(binid), 'function_name'])
    with open(f, "rb") as fh:
        saved_index = pickle.load(fh)
    new_index = {
        func_name: data
        for func_name, data in saved_index.items()
        if func_name in wanted
    }
    with open(os.path.join("./extract_selected", os.path.basename(f)), "wb") as fh:
        pickle.dump(new_index, fh)

