In [1]:
#!g1.1
# %pip install -U plotly==5.11.0 pandas==1.3.5 pytorch_lightning==1.8.0 numpy==1.21.6 ipywidgets==8.0.2 
# %pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu116.html
# %pip install tensorboard==2.11.0

In [2]:
#!g1.1
import os
import glob
import random
import pickle
import shutil

import plotly.express as px
import pandas as pd
import torch
import pytorch_lightning as pl
import numpy as np
import torchtext
import torch_geometric
from tqdm import tqdm
import identify_x86_data

## Dataset loading

In [3]:
#!g1.1

# EXECUTABLE = 'bin_zsh5'
# EXECUTABLE = 'bin_gzip'
# EXECUTABLE = 'usr_lib_gcc_i686-linux-gnu_8_lto1'

random.seed(42)

superset_files = glob.glob('superset/**/*.superset', recursive=True)
superset_files = list(sorted((os.path.getsize(f), f) for f in superset_files))
superset_files = [f for s, f in superset_files if s < 10_000_000]
random.shuffle(superset_files)
# print(*superset_files, sep='\n')
print(f"Using {len(superset_files)} superset files!")

train_files, test_files = superset_files[:40], superset_files[40:]

Using 53 superset files!


In [4]:
MAX_ISN_SIZE = 15

def load_superset(filename):
    df = pd.read_parquet(filename)
    df.code = df.code.map(lambda x: identify_x86_data.INSTR_CODES[x])
    df.set_index('addr', inplace=True)
    df.sort_index(inplace=True)
    return df

In [5]:
from torchtext.vocab import vocab as make_vocab
from collections import OrderedDict, Counter

# TODO: this vocab should be built over the whole training dataset

counts = Counter()

for filename in tqdm(train_files):
    df = load_superset(filename)
    df_counts = df.code.value_counts().to_dict()
    counts.update(df_counts)

# print(counts)

known = { x[0]: i for i, x in enumerate(counts.most_common(200)) }
vocab = make_vocab(known, specials=['INVALID', 'UNKNOWN'])
vocab.set_default_index(vocab['UNKNOWN'])

100%|██████████| 40/40 [00:02<00:00, 18.01it/s]


In [6]:
#!g1.1
vocab['Dec_r32']

8

## Build graph from the loaded data

In [7]:
#!g1.1
from torch_geometric.data import Data
from typing import List, Tuple
import gc

def encode_instructions(instr):
    code = torch.tensor(vocab(instr.code.to_list()))
    size = torch.tensor(instr['size'].map(lambda x: x-1).values)
    labels = torch.tensor(instr['label'].values)

    return code, size, labels

EDGE_NEXT = 0
EDGE_PREV = 1
EDGE_OVERLAP = 2
EDGE_RELCOUNT = 3

class EdgesBuilder:
    def __init__(self):
        self.idx_buffer = []
        self.ty_buffer = []
        self.edge_count = 0
        self.edge_idx_parts = []
        self.edge_ty_parts = []

    def add_edge(self, src, dst, kind):
        self.idx_buffer.append((src, dst))
        self.ty_buffer.append(kind)
        self.edge_count += 1
        
        if len(self.idx_buffer) >= 0x80000: # TODO: tune
            self.edge_idx_parts.append(torch.tensor(self.idx_buffer, dtype=torch.long))
            self.edge_ty_parts.append(torch.tensor(self.ty_buffer, dtype=torch.long))
            self.idx_buffer.clear()
            self.ty_buffer.clear()

    def build(self):
        self.edge_idx_parts.append(torch.tensor(self.idx_buffer, dtype=torch.long))
        self.edge_ty_parts.append(torch.tensor(self.ty_buffer, dtype=torch.long))
        self.idx_buffer.clear()
        self.ty_buffer.clear()

        edge_idx = torch.cat(self.edge_idx_parts)
        self.edge_idx_parts.clear()
        gc.collect()
        
        edge_ty = torch.cat(self.edge_ty_parts)
        self.edge_ty_parts.clear()
        gc.collect()

        return edge_idx, edge_ty

    def __len__(self):
        return self.edge_count

def build_executable_graph(df):
    G = Data()
    G.num_nodes = df.shape[0]
    G.x_code, G.x_size, G.y = encode_instructions(df)
    # the classes are stored as booleans for efficiency
    # convert them to long for the loss function
    G.y = G.y.to(torch.long)

    edges = EdgesBuilder()

    t = tqdm(df.iterrows(), total=df.shape[0])
    for addr, x in t:
        i = df.index.get_loc(addr)
        next_addr = addr + x.size
        try:
            j = df.index.get_loc(next_addr)

            edges.add_edge(i, j, EDGE_NEXT)
            edges.add_edge(j, i, EDGE_PREV)
        except KeyError:
            pass

        for o in range(addr+1, next_addr):
            try:
                j = df.index.get_loc(o)
                edges.add_edge(i, j, EDGE_OVERLAP)
                edges.add_edge(j, i, EDGE_OVERLAP)
            except KeyError:
                pass
        if addr % 0x1000 == 0:
            t.set_description(f'edges: {len(edges)}')

    edge_idx, edge_ty = edges.build()
    del edges
    gc.collect()

    # print(edge_idx.shape)
    # print(edge_idx)

    # print(edge_ty)

    G.num_edges = edge_idx.shape[0]
    G.edge_index = torch.swapaxes(edge_idx, 0, 1)
    G.edge_type = edge_ty

    return G

In [8]:
#!g1.1
# TODO: use PyG's Dataset class

# TODO: now that we have a list of superset files we should be able to create a dataset that caches the graph conversion

from torch_geometric.data import Dataset

class IdentifyDataset(Dataset):
    def __init__(self, files, root, transform=None, pre_transform=None):
        self.files = files
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return self.files

    @property
    def processed_file_names(self):
        return [f'{f}.pt' for f in self.files]

    def download(self):
        for filename, raw_path in zip(self.raw_file_names, self.raw_paths):
            os.makedirs(os.path.dirname(raw_path), exist_ok=True)
            shutil.copy(filename, raw_path)

    def process(self):
        for filename, processed_path in zip(tqdm(self.raw_paths), self.processed_paths):
            df = load_superset(filename)

            # TODO: I want to do this transformation in rust (python ~~sucks ass~~ is slow)
            # either way we would need to implement graph construction in rust for deployment
            G = build_executable_graph(df)

            if self.pre_filter is not None:
                data_list = [data for data in data_list if self.pre_filter(data)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(data) for data in data_list]

            os.makedirs(os.path.dirname(processed_path), exist_ok=True)
            torch.save(G, processed_path)
    
    def get(self, idx):
        data = torch.load(self.processed_paths[idx])
        if self.transform is not None:
            data = self.transform(data)
        return data

    def len(self):
        return len(self.processed_file_names)

train_dataset = IdentifyDataset(train_files, 'data/train')
test_dataset = IdentifyDataset(test_files, 'data/test')

# G = build_executable_graph(df)

In [9]:
#!g1.1
from torch_geometric.nn import Sequential, RGCNConv, Linear
from torch.nn import Embedding, ReLU, Sigmoid
import torchmetrics
import torchmetrics.classification

In [10]:
#!g1.1

from torch import Tensor

class IdentifyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        size_embed_size = 4
        code_embed_size = 32

        @torch.jit.script
        def cat(x1, x2):
            return torch.cat([x1, x2], dim=1)

        self.model = Sequential('x_code, x_size, edge_index, edge_type', [
            (Embedding(num_embeddings=MAX_ISN_SIZE, embedding_dim=size_embed_size), 'x_size -> x_size'),
            (Embedding(num_embeddings=len(vocab), embedding_dim=code_embed_size), 'x_code -> x_code'),
            (cat, 'x_size, x_code -> x'),
            (RGCNConv(size_embed_size + code_embed_size, 24, EDGE_RELCOUNT).jittable(), 'x, edge_index, edge_type -> x'),
            ReLU(inplace=True),
            (RGCNConv(24, 16, EDGE_RELCOUNT).jittable(), 'x, edge_index, edge_type -> x'),
            ReLU(inplace=True),
            (RGCNConv(16, 8, EDGE_RELCOUNT).jittable(), 'x, edge_index, edge_type -> x'),
            ReLU(inplace=True),
            (RGCNConv(8, 4, EDGE_RELCOUNT).jittable(), 'x, edge_index, edge_type -> x'),
            ReLU(inplace=True),
            Linear(4, 2),
        ])
    
    def forward(self, x_code: Tensor, x_size: Tensor, edge_index: Tensor, edge_type: Tensor) -> Tensor:
        return self.model(x_code, x_size, edge_index, edge_type)


class LightningModel(pl.LightningModule):
    def __init__(self):
        super(LightningModel, self).__init__()

        self.model = IdentifyModel()

        self.train_accuracy = torchmetrics.classification.BinaryAccuracy()
        self.train_precision = torchmetrics.classification.BinaryPrecision()
        self.train_recall = torchmetrics.classification.BinaryRecall()

        self.valid_accuracy = torchmetrics.classification.BinaryAccuracy()
        self.valid_precision = torchmetrics.classification.BinaryPrecision()
        self.valid_recall = torchmetrics.classification.BinaryRecall()


    def forward(self, x_code: Tensor, x_size: Tensor, edge_index: Tensor, edge_type: Tensor):
        x_out = self.model(x_code, x_size, edge_index, edge_type)

        return x_out

    def training_step(self, batch, batch_index):
        x_code, x_size, edge_index, edge_type = \
            batch.x_code, batch.x_size, batch.edge_index, batch.edge_type

        x_out = self.forward(x_code, x_size, edge_index, edge_type)

        loss = torch.nn.functional.cross_entropy(x_out, batch.y)

        # metrics here
        pred = x_out.argmax(-1)
        label = batch.y
        
        self.train_accuracy(pred, label)
        self.train_precision(pred, label)
        self.train_recall(pred, label)

        self.log("loss/train", loss)
        self.log("accuracy/train", self.train_accuracy, on_step=True, on_epoch=False)
        self.log("recall/train", self.train_recall, on_step=True, on_epoch=False)
        self.log("precision/train", self.train_precision, on_step=True, on_epoch=False)

        return loss

    def validation_step(self, batch, batch_index):
        x_code, x_size, edge_index, edge_type = \
            batch.x_code, batch.x_size, batch.edge_index, batch.edge_type

        x_out = self.forward(x_code, x_size, edge_index, edge_type)

        #loss = torch.nn.functional.cross_entropy(x_out, batch.y)

        pred = x_out.argmax(-1)

        self.valid_accuracy(pred, batch.y)
        self.valid_precision(pred, batch.y)
        self.valid_recall(pred, batch.y)

        self.log("accuracy/val", self.valid_accuracy, on_step=True, on_epoch=True)
        self.log("recall/val", self.valid_recall, on_step=True, on_epoch=True)
        self.log("precision/val", self.valid_precision, on_step=True, on_epoch=True)

        return x_out, pred, batch.y

    def validation_epoch_end(self, validation_step_outputs):
        val_loss = 0.0
        num_correct = 0
        num_total = 0
        num_tp = 0
        num_tn = 0
        num_fp = 0
        num_fn = 0

        for output, pred, labels in validation_step_outputs:
            val_loss += torch.nn.functional.cross_entropy(output, labels, reduction="sum")

        self.log("loss/val", val_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr = 3e-4)

In [11]:
#!g1.1
model = LightningModel()

In [12]:
#!g1.1
model

LightningModel(
  (model): IdentifyModel(
    (model): Sequential(
      (0): Embedding(15, 4)
      (1): Embedding(200, 32)
      (2): <torch.jit.ScriptFunction object at 0x7f1e2d7baa20>
      (3): RGCNConvJittable_52c1d6(36, 24, num_relations=3)
      (4): ReLU(inplace=True)
      (5): RGCNConvJittable_52c8ae(24, 16, num_relations=3)
      (6): ReLU(inplace=True)
      (7): RGCNConvJittable_52d0fb(16, 8, num_relations=3)
      (8): ReLU(inplace=True)
      (9): RGCNConvJittable_52db01(8, 4, num_relations=3)
      (10): ReLU(inplace=True)
      (11): Linear(4, 2, bias=True)
    )
  )
  (train_accuracy): BinaryAccuracy()
  (train_precision): BinaryPrecision()
  (train_recall): BinaryRecall()
  (valid_accuracy): BinaryAccuracy()
  (valid_precision): BinaryPrecision()
  (valid_recall): BinaryRecall()
)

In [13]:
#!g1.1
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import TensorBoardLogger

print("Cuda is available:", torch.cuda.is_available())

train_loader = DataLoader(train_dataset, batch_size=None)
test_loader = DataLoader(test_dataset, batch_size=None)
val_loader = DataLoader(test_dataset, batch_size=None)

model = LightningModel()
num_epochs = 2000
# val_check_interval = len(train_loader)

trainer = pl.Trainer(
    max_epochs = num_epochs,
    # val_check_interval = val_check_interval,
    log_every_n_steps = 1,
    accelerator = 'gpu',
    enable_progress_bar = False,
)
trainer.fit(model, train_loader, val_loader)

Cuda is available: True


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type            | Params
----------------------------------------------------
0 | model           | IdentifyModel   | 12.2 K
1 | train_accuracy  | BinaryAccuracy  | 0     
2 | train_precision | BinaryPrecision | 0     
3 | train_recall    | BinaryRecall    | 0     
4 | valid_accuracy  | BinaryAccuracy  | 0     
5 | valid_precision | BinaryPrecision | 0     
6 | valid_recall    | BinaryRecall    | 0     
----------------------------------------------------
12.2 K    Trainable params
0         Non-trainable params
12.2 K    Total params
0.049     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=2000` reached.


In [16]:
#!g1.1
# TODO: this is not that easy...
# model.eval()
# for param in model.parameters():
    # param.requires_grad = False
model_jit = torch.jit.script(model.model)
print("Model JIT:", model_jit)

model_jit.save(f"model_jit.pt")

# torch.onnx.export(model_jit, (
#     torch.tensor([0], dtype=torch.long),
#     torch.tensor([0], dtype=torch.long),
#     torch.tensor([[0, 0]], dtype=torch.long),
#     torch.tensor([0], dtype=torch.long),
# ), f"{EXECUTABLE}.onnx", verbose=True)
# with torch.no_grad():

    # torch.onnx.export(model, (G.x_code, G.x_size, G.edge_index, G.edge_type), f"{EXECUTABLE}.onnx", verbose=True)
# torch.jit.save(model, f"{EXECUTABLE}.pt")

Model JIT: RecursiveScriptModule(
  original_name=IdentifyModel
  (model): RecursiveScriptModule(
    original_name=Sequential_530972
    (module_0): RecursiveScriptModule(original_name=Embedding)
    (module_1): RecursiveScriptModule(original_name=Embedding)
    (module_3): RecursiveScriptModule(
      original_name=RGCNConvJittable_52ee10
      (aggr_module): RecursiveScriptModule(original_name=MeanAggregation)
    )
    (module_4): RecursiveScriptModule(original_name=ReLU)
    (module_5): RecursiveScriptModule(
      original_name=RGCNConvJittable_52f810
      (aggr_module): RecursiveScriptModule(original_name=MeanAggregation)
    )
    (module_6): RecursiveScriptModule(original_name=ReLU)
    (module_7): RecursiveScriptModule(
      original_name=RGCNConvJittable_52ff9a
      (aggr_module): RecursiveScriptModule(original_name=MeanAggregation)
    )
    (module_8): RecursiveScriptModule(original_name=ReLU)
    (module_9): RecursiveScriptModule(
      original_name=RGCNConvJittable_5

In [15]:
model.model.model[-1].weight

Parameter containing:
tensor([[-0.5333, -0.9070,  0.8564, -0.1676],
        [-0.2439,  0.8260, -0.6739,  0.1237]], requires_grad=True)