# Next Instruction Prediction Training


In [1]:
import torch

torch.cuda.is_available()

  from .autonotebook import tqdm as notebook_tqdm


True

# DATASET GENERATION

In [2]:
#!/usr/bin/env python3

import sys,os
from elftools.elf.elffile import ELFFile
from elftools.elf.segments import Segment

filePath = './../../binaries/gnuit/src/gitfm'
fh = open(filePath, 'rb')
bin_bytearray = bytearray(fh.read())

In [3]:


from capstone import *

from capstone.x86 import *


address_inst = {}
with open('./data/instruction_clusters.txt', 'w') as data_file:
    with open(filePath, 'rb') as f:
        elf = ELFFile(f)
        dwarfinfo = elf.get_dwarf_info()
        aranges = dwarfinfo.get_aranges()
        print(len(aranges.entries))
    #     for arange in aranges.entries:
    #         print(arange)
        for arange in aranges.entries:

            entry = arange.begin_addr
            exit  = arange.begin_addr + arange.length
            ops = bin_bytearray[entry: exit]

            md = Cs(CS_ARCH_X86, CS_MODE_64)
            md.detail = True
            for inst in md.disasm(ops, entry):

                address_inst[hex(inst.address)] = inst
                data_file.write(inst.mnemonic+" "+inst.op_str+";")
            data_file.write('\n')
    #             print( inst.mnemonic+"  "+inst.op_str)


341


# Creating the pipeline

In [4]:
from transformers import BertTokenizer, BertForNextSentencePrediction
import torch

# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

tokenizer = BertTokenizer.from_pretrained("./binary-tokenizer")
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForNextSentencePrediction: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
delim = ';'
with open('./data/instruction_clusters.txt', 'r') as fp:
    text = fp.read().split('\n')

In [6]:
text[1]

'endbr64 ;push rbp;mov rbp, rsp;mov eax, dword ptr [rip + 0x2b9f5];cmp eax, 6;jle 0x501a;mov eax, dword ptr [rip + 0x2a036];test eax, eax;je 0x5008;mov eax, dword ptr [rip + 0x2b9e4];cmp eax, 0xb;jle 0x501a;mov eax, 1;jmp 0x501f;mov eax, dword ptr [rip + 0x2b9d2];cmp eax, 5;jle 0x501a;mov eax, 1;jmp 0x501f;mov eax, 0;pop rbp;ret ;'

We need to split sentences into consecutive, and non-consecutive sequences.

We have to deal with edge-cases too - for example where there is only a single sentence within a paragraph as with the three examples above (in comparison to below where we can easily split into multiple sentences).

In [7]:
text[51].split(delim)

['endbr64 ',
 'push rbp',
 'mov rbp, rsp',
 'sub rsp, 0x10',
 'mov dword ptr [rbp - 4], edi',
 'mov eax, dword ptr [rip + 0x24da4]',
 'cmp eax, 1',
 'jne 0xbc8e',
 'movzx eax, byte ptr [rip + 0x24d8d]',
 'movzx eax, al',
 'sar eax, 6',
 'and eax, 1',
 'cmp dword ptr [rbp - 4], eax',
 'je 0xbd57',
 'cmp dword ptr [rbp - 4], 1',
 'jne 0xbcc1',
 'mov rax, qword ptr [rip + 0x2404d]',
 'test rax, rax',
 'je 0xbd2e',
 'mov rax, qword ptr [rip + 0x2403d]',
 'lea rdx, [rip - 0xc66]',
 'mov esi, 1',
 'mov rdi, rax',
 'call 0x47e0',
 'jmp 0xbd2e',
 'mov rax, qword ptr [rip + 0x23fe0]',
 'test rax, rax',
 'je 0xbce8',
 'mov rax, qword ptr [rip + 0x23fd4]',
 'lea rdx, [rip - 0xc8f]',
 'mov esi, 1',
 'mov rdi, rax',
 'call 0x47e0',
 'mov dword ptr [rip + 0x24d1a], 0',
 'mov dword ptr [rip + 0x24d14], 0',
 'movzx eax, byte ptr [rip + 0x24d06]',
 'and eax, 0xffffffbf',
 'mov byte ptr [rip + 0x24cfd], al',
 'movzx eax, byte ptr [rip + 0x24cf6]',
 'shr al, 7',
 'cmp al, 1',
 'jne 0xbd2e',
 'mov dword p

We'll assign a 50% probability of using the genuine next sentence, and 50% probability of using another random sentence.

To make this simpler, we'll create a *'bag'* of individual sentences to pull from when selecting a random sentence B.

In [8]:
bag = [instruction for instruction_cluster in text for instruction in instruction_cluster.split(delim)  if instruction!= '']
bag_size = len(bag)
print(bag_size)

33455


In [9]:
bag

['call 0x4810',
 'endbr64 ',
 'push rbp',
 'mov rbp, rsp',
 'mov eax, dword ptr [rip + 0x2b9f5]',
 'cmp eax, 6',
 'jle 0x501a',
 'mov eax, dword ptr [rip + 0x2a036]',
 'test eax, eax',
 'je 0x5008',
 'mov eax, dword ptr [rip + 0x2b9e4]',
 'cmp eax, 0xb',
 'jle 0x501a',
 'mov eax, 1',
 'jmp 0x501f',
 'mov eax, dword ptr [rip + 0x2b9d2]',
 'cmp eax, 5',
 'jle 0x501a',
 'mov eax, 1',
 'jmp 0x501f',
 'mov eax, 0',
 'pop rbp',
 'ret ',
 'endbr64 ',
 'push rbp',
 'mov rbp, rsp',
 'mov eax, dword ptr [rip + 0x2b8b1]',
 'cmp eax, 1',
 'sete al',
 'movzx eax, al',
 'pop rbp',
 'ret ',
 'endbr64 ',
 'push rbp',
 'mov rbp, rsp',
 'push rbx',
 'sub rsp, 0x38',
 'mov dword ptr [rbp - 0x34], edi',
 'mov dword ptr [rbp - 0x28], 0',
 'mov dword ptr [rbp - 0x24], 0',
 'mov eax, dword ptr [rip + 0x2b97e]',
 'mov dword ptr [rbp - 0x20], eax',
 'mov eax, dword ptr [rip + 0x2b979]',
 'mov dword ptr [rbp - 0x1c], eax',
 'mov eax, 0',
 'call 0xc866',
 'cmp dword ptr [rbp - 0x34], 0',
 'jne 0x5094',
 'mov eax

And now we create our 50/50 NIP training data.

In [10]:
import random

history = []
next_instruction = []
label = []

page_len = 5
instruction_pages = []
for instruction_cluster in text:
    instructions = [
        instruction for instruction in instruction_cluster.split(delim) if instruction != ''
    ]
    if len(instructions)>page_len:
        
        for i in range(0,len(instructions),page_len):
            instruction_pages.append(instructions[i:i+page_len])
        
print(len(instruction_pages))
print(instruction_pages[0])

for instruction_page in instruction_pages:
    
#     instructions = [
#         instruction for instruction in instruction_page.split(';') if instruction != ''
#     ]
    
    
#     num_instructions = len(instruction_page)
    
    

#     start = random.randint(0, num_instructions-2)
    # 50/50 whether is IsNextSentence or NotNextSentence
    if random.random() >= 0.5:
        # this is IsNextSentence
        history.append(delim.join(instruction_page[:-1]))
        next_instruction.append(instruction_page[-1])
        label.append(0)
    else:
        index = random.randint(0, bag_size-1)
        # this is NotNextSentence
        history.append(delim.join(instruction_page[:-1]))
        next_instruction.append(bag[index])
        label.append(1)

6827
['endbr64 ', 'push rbp', 'mov rbp, rsp', 'mov eax, dword ptr [rip + 0x2b9f5]', 'cmp eax, 6']


In [11]:
print(len(label))
for i in range(3):
    print(label[i])
    print('->',history[i] , '\n')
    print('# ',next_instruction[i] , '\n')

6827
0
-> endbr64 ;push rbp;mov rbp, rsp;mov eax, dword ptr [rip + 0x2b9f5] 

#  cmp eax, 6 

1
-> jle 0x501a;mov eax, dword ptr [rip + 0x2a036];test eax, eax;je 0x5008 

#  call 0x4960 

0
-> cmp eax, 0xb;jle 0x501a;mov eax, 1;jmp 0x501f 

#  mov eax, dword ptr [rip + 0x2b9d2] 



Our data is now ready for tokenization, this time we truncate/pad each token to the same length of *512* tokens.

In [12]:
inputs = tokenizer(history, next_instruction, return_tensors='pt', max_length=128, truncation=True, padding='max_length')

In [13]:
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

We can see that the *token_type_ids* tensors have been built correctly (eg **1** indicating sentence B tokens) by checking the first instance of *token_type_ids*:

In [14]:
inputs.token_type_ids[0]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])

The **0** tokens following our sentence B tokens correspond to *PAD* tokens.

Alongside this, we need to create a *labels* tensor too - which corresponds to the values contained within our `label` variable. Our *labels* tensor must be a *LongTensor*, and we will need to transpose the tensor so that it matches our other tensors' dimensionality.

In [15]:
inputs['labels'] = torch.LongTensor([label]).T

In [16]:
inputs.labels[:10]

tensor([[0],
        [1],
        [0],
        [1],
        [1],
        [0],
        [0],
        [0],
        [1],
        [0]])

The `inputs` tensors are now ready, and we can begin building the model input pipeline for training. We first create a PyTorch dataset from our data.

In [17]:
class MeditationsDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
    def __len__(self):
        return len(self.encodings.input_ids)

Initialize our data using the `MeditationDataset` class.

In [18]:
dataset = MeditationsDataset(inputs)

And initialize the dataloader, which we'll be using to load our data into the model during training.

In [19]:
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

Now we can move onto setting up the training loop. First we setup GPU/CPU usage.

In [20]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# and move our model over to the selected device
model.to(device)

BertForNextSentencePrediction(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

Activate the training mode of our model, and initialize our optimizer (Adam with weighted decay - reduces chance of overfitting).

In [21]:
from transformers import AdamW

# activate training mode
model.train()
# initialize optimizer
optim = AdamW(model.parameters(), lr=5e-6)



In [22]:
['__annotations__', '__class__', '__contains__', '__dataclass_fields__', '__dataclass_params__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__post_init__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__setattr__', '__setitem__', '__sizeof__', '__str__', '__subclasshook__', 'attentions', 'clear', 'copy', 'fromkeys', 'get', 'hidden_states', 'items', 'keys', 'logits', 'loss', 'move_to_end', 'pop', 'popitem', 'setdefault', 'to_tuple', 'update', 'values']

['__annotations__',
 '__class__',
 '__contains__',
 '__dataclass_fields__',
 '__dataclass_params__',
 '__delattr__',
 '__delitem__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__post_init__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__reversed__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'attentions',
 'clear',
 'copy',
 'fromkeys',
 'get',
 'hidden_states',
 'items',
 'keys',
 'logits',
 'loss',
 'move_to_end',
 'pop',
 'popitem',
 'setdefault',
 'to_tuple',
 'update',
 'values']

In [23]:
from sklearn.metrics import precision_recall_fscore_support , accuracy_score
import numpy as np


Now we can move onto the training loop, we'll train for a couple of epochs (change `epochs` to modify this).

In [None]:
from tqdm import tqdm  # for our progress bar

epochs = 10000

for epoch in range(epochs):
    # setup loop with TQDM and dataloader
    loop = tqdm(loader, leave=True)
    
    
    predictions_all, ground_truths_all = None, None
    for N,batch in enumerate(loop):

        optim.zero_grad()
        # pull all tensor batches required for training
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        labels = batch['labels'].to(device)
        # process
        outputs = model(input_ids, attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                        labels=labels)
#         print(torch.nn.functional.softmax(outputs.logits, dim=-1))
        prediction = torch.argmax(outputs.logits, axis=-1)
        prediction = prediction.detach().cpu().numpy().flatten()
        ground_truth = labels.detach().cpu().numpy().flatten()
        
        if N==0:
            predictions_all = prediction
            ground_truths_all = ground_truth
        else:
            predictions_all   = np.concatenate((predictions_all, prediction))
            ground_truths_all = np.concatenate((ground_truths_all, ground_truth))
            
#         predictions_all.append(prediction)
#         ground_truths_all.append(ground_truth)
        
#         print(ground_truth.flatten())
#         print(predictions_all ,ground_truths_all )
#         print(prediction, ground_truth ,(accuracy_score(ground_truth.flatten(),prediction.flatten())))
        

        # extract loss
        loss = outputs.loss
        # calculate loss for every parameter that needs grad update
        loss.backward()
        # update parameters
        optim.step()
        # print relevant info to progress bar
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())
    accuracy = (accuracy_score(ground_truths_all,predictions_all))
    precision, recall, f1, _ = precision_recall_fscore_support(ground_truths_all,predictions_all, average='binary')
    print(accuracy, precision, recall, f1, _)

  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 0: 100%|████████████████████| 427/427 [02:56<00:00,  2.42it/s, loss=0.715]


0.49655778526439137 0.49057605521635256 0.5490196078431373 0.5181550539744848 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 1: 100%|████████████████████| 427/427 [02:57<00:00,  2.40it/s, loss=0.671]


0.5085689175333236 0.5015882183078256 0.516042780748663 0.5087128422902328 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 2: 100%|████████████████████| 427/427 [02:55<00:00,  2.43it/s, loss=0.699]


0.5088618719789073 0.5020092735703245 0.48247177658942364 0.4920466595970307 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 3: 100%|████████████████████| 427/427 [02:54<00:00,  2.44it/s, loss=0.695]


0.5155998242273326 0.5093740069907848 0.476232917409388 0.4922462766774144 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 4: 100%|████████████████████| 427/427 [02:54<00:00,  2.44it/s, loss=0.699]


0.5098872125384503 0.5031525851197982 0.4741532976827095 0.48822269807280516 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 5: 100%|████████████████████| 427/427 [02:54<00:00,  2.44it/s, loss=0.701]


0.5210194814706313 0.5158415841584159 0.464349376114082 0.48874296435272047 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 6: 100%|████████████████████| 427/427 [02:54<00:00,  2.44it/s, loss=0.685]


0.5352277720814413 0.5297380585516178 0.5106951871657754 0.5200423536530026 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 7: 100%|████████████████████| 427/427 [02:55<00:00,  2.43it/s, loss=0.741]


0.5359601581954007 0.5321011673151751 0.4875222816399287 0.5088372093023256 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 8: 100%|████████████████████| 427/427 [02:53<00:00,  2.47it/s, loss=0.705]


0.5400615204335726 0.5365459249676585 0.4928698752228164 0.5137813564571075 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 9: 100%|████████████████████| 427/427 [02:54<00:00,  2.45it/s, loss=0.596]


0.5462135637908305 0.5468859342197341 0.464349376114082 0.5022493573264781 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 10: 100%|███████████████████| 427/427 [02:54<00:00,  2.44it/s, loss=0.687]


0.5627654899663103 0.5729047072330654 0.4447415329768271 0.5007526342197692 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 11: 100%|███████████████████| 427/427 [02:58<00:00,  2.40it/s, loss=0.744]


0.5591035593965138 0.5683039140445126 0.43998811645870467 0.49598124581379777 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 12: 100%|████████████████████| 427/427 [03:04<00:00,  2.31it/s, loss=0.55]


0.5765343489087447 0.5838926174496645 0.4910873440285205 0.5334839438437955 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 13: 100%|███████████████████| 427/427 [03:04<00:00,  2.32it/s, loss=0.496]


0.6135930862750842 0.6374622356495468 0.5014854426619133 0.5613568340538743 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 14: 100%|███████████████████| 427/427 [02:59<00:00,  2.38it/s, loss=0.526]


0.6466969386260436 0.6773234200743494 0.5412953060011884 0.6017173051519155 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 15: 100%|███████████████████| 427/427 [02:53<00:00,  2.47it/s, loss=0.515]


0.6745276109564963 0.7050179211469534 0.5843731431966727 0.6390513320337882 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 16: 100%|███████████████████| 427/427 [02:53<00:00,  2.47it/s, loss=0.364]


0.7177383916800938 0.7541504768632992 0.6342840166369578 0.6890430853638857 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 17: 100%|███████████████████| 427/427 [02:53<00:00,  2.47it/s, loss=0.521]


0.7413212245495825 0.7711864406779662 0.6758764111705288 0.720392653578214 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 18: 100%|███████████████████| 427/427 [02:53<00:00,  2.47it/s, loss=0.569]


0.7690054196572433 0.805603006491288 0.7005347593582888 0.7494040997934212 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 19: 100%|███████████████████| 427/427 [02:53<00:00,  2.47it/s, loss=0.641]


0.781455983594551 0.8180583842498302 0.7159833630421866 0.7636248415716095 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 20: 100%|████████████████████| 427/427 [02:53<00:00,  2.47it/s, loss=0.42]


0.7980079097700308 0.8260584181161799 0.7477718360071302 0.7849680336815844 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 21: 100%|███████████████████| 427/427 [02:53<00:00,  2.47it/s, loss=0.242]


0.811044382598506 0.837890625 0.7647058823529411 0.799627213420317 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 22: 100%|███████████████████| 427/427 [02:53<00:00,  2.47it/s, loss=0.512]


0.8280357404423612 0.8519588953114965 0.7881758764111705 0.8188271604938273 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 23: 100%|██████████████████| 427/427 [02:53<00:00,  2.47it/s, loss=0.0788]


0.8327230115717006 0.8589412524209167 0.7905525846702317 0.823329207920792 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 24: 100%|███████████████████| 427/427 [02:53<00:00,  2.47it/s, loss=0.315]


0.8423905082759631 0.8616550852811118 0.8104575163398693 0.835272504592774 None


  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 25:  19%|███▊                | 82/427 [00:33<02:25,  2.37it/s, loss=0.135]

In [None]:
print(ground_truths)