In [124]:
import datasets 
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset
from tokenizers.normalizers import NFKC
from tokenizers.pre_tokenizers import Whitespace
from tokenizers import Tokenizer, trainers, models
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling
from transformers import RobertaForMaskedLM, RobertaConfig
from transformers import PreTrainedTokenizerFast, RobertaTokenizerFast

import os 
from typing import List
from collections import defaultdict

DATA_PATH = "/Volumes/New Volume/malware-detection-dataset/opcodes/processed-data"
MAX_LENGTH = 64

In [16]:
def get_data(path: os.PathLike, full_path: bool = True) -> List[str]:
    all_files = os.listdir(path)
    
    if full_path:
        return [os.path.join(path, file) for file in all_files if file.endswith('.txt') and not file.startswith("._")]
    else: 
        return all_files

def get_labels(filenames):
    return [1 if "VirusShare" in filename else 0 for filename in filenames]

paths = get_data(DATA_PATH)
labels = get_labels(paths)

In [17]:
class OpcodeDataset(Dataset): 
    def __init__(self, paths, labels):
        assert len(paths) == len(labels), "Mismatch between number of files and labels"
        self.paths = paths 
        self.labels = labels

    def __len__(self):
        return len(self.paths)        


    def __getitem__(self, idx):
        assert 0 <= idx <= len(self), "Index out of range"
        label = self.labels[idx]

        with open(self.paths[idx], 'r') as file: 
            content = file.readlines() 
            
        return ' '.join([opcode.rstrip() for opcode in content]), label

opcode_dataset = OpcodeDataset(paths, labels)

In [18]:
if not os.path.exists('./MalBERTa'):
    tokenizer = Tokenizer(models.WordLevel(unk_token="<unk>"))
    tokenizer.normalizer = NFKC()
    tokenizer.pre_tokenizer = Whitespace()

    trainer = trainers.WordLevelTrainer(
        vocab_size=1293, 
        special_tokens=[
            "<s>",
            "<pad>",
            "</s>",
            "<unk>",
            "<mask>",
        ], 
    )
    tokenizer.train(paths, trainer)
    tokenizer.save('MalBERTa/tokenizer.json')

    hf_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file="MalBERTa/tokenizer.json",
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        mask_token="<mask>"
    )
    hf_tokenizer.save_pretrained("MalBERTa")
    tokenizer = hf_tokenizer
else: 
    tokenizer = PreTrainedTokenizerFast.from_pretrained("MalBERTa")

In [None]:
def dataset_generator():
    for text, label in tqdm(opcode_dataset): 
        yield {
            "text": text,
            "label": label
        }

if not os.path.exists('./data/raw'):
    dataset = datasets.Dataset.from_generator(dataset_generator)
    dataset = dataset.train_test_split(test_size=0.2)
    dataset.save_to_disk("data/raw")
else: 
    dataset = datasets.load_from_disk("./data/raw")

Saving the dataset (6/6 shards): 100%|██████████| 5552/5552 [00:02<00:00, 1863.22 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 1388/1388 [00:00<00:00, 2638.87 examples/s]


In [122]:
def handle_sample(sample):
    texts = sample['text']
    labels = sample['label']
    
    flattened = defaultdict(list)

    for text, label in zip(texts, labels):
        tokenized = tokenizer(
            text,
            padding='max_length',
            max_length=MAX_LENGTH,
            return_overflowing_tokens=True,
            truncation=True,
            return_special_tokens_mask=True,
        )

        for i in range(len(tokenized['input_ids'])):
            for k in tokenized:
                flattened[k].append(tokenized[k][i])
            flattened['label'].append(label)

    return dict(flattened)

processed_dataset = dataset.map(
    handle_sample,
    remove_columns=dataset['test'].column_names,
    batch_size=64,
    batched=True,
    num_proc=8,
)


Python(314) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(317) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(318) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(321) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(322) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(323) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(324) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(326) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Map (num_proc=8):   0%|          | 0/5552 [00:00<?, ? examples/s]Python(328) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(330) MallocStackLogging: can't turn off malloc stack logging bec

RuntimeError: One of the subprocesses has abruptly died during map operation.To debug the error, disable multiprocessing.

In [None]:
config = RobertaConfig(
    vocab_size=tokenizer.vocab_size, 
    max_position_embeddings=MAX_LENGTH + 2, 
    num_attention_heads=4,
    num_hidden_layers=4,
    type_vocab_size=1,
    hidden_size=128,
    intermediate_size=2048,
)

model = RobertaForMaskedLM(config=config)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

train_ds = processed_dataset['train'].remove_columns('label')
test_ds = processed_dataset['test'].remove_columns('label')

train_args = TrainingArguments(
    output_dir="./MalBERTa",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=64, 
    save_steps=10_000, 
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=train_args, 
    processing_class=tokenizer,
    data_collator=data_collator,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)

# trainer.train()

In [121]:
import torch
import pandas as pd

def predict(token_ids):
    X = data_collator(torch.tensor(token_ids['input_ids']))
    preds = trainer.predict(X['input_ids'])
    
    Y_hat = tokenizer.batch_decode(preds.predictions.argmax(-1))
    Y = tokenizer.batch_decode(token_ids['input_ids'])

    df = pd.DataFrame(data={
        "Input": tokenizer.batch_decode(X['input_ids']),
        "Predicted": Y_hat,
        "Actual": Y,
    })

    return df
    
data = test_ds.select(range(10))
predict(data)

Unnamed: 0,Input,Predicted,Actual
0,push mov in sub in add nop mov inc or push mov...,maskmovdqu vprolvd vmovlps vpternlogq pfpnacc ...,push mov in sub in add nop mov inc or push mov...
1,add lea inc in push mov add push lea inc hlt p...,movddup fsubp fsubp phsubd vandnps wrmsr movdd...,add lea inc in push mov add push lea inc hlt p...
2,jmp and inc and inc add jmp and dec and inc ad...,vpternlogq fsubp fsubp fsubp psllw korw vpmadc...,jmp and inc and inc add jmp and dec and inc ad...
3,<mask> <mask> <mask> add add lea mov <mask> <m...,xtest vcmpltps fsubp vunpcklpd vfmadd213pd fsu...,ret lea mov add add lea mov add add nop sub in...
4,mov add inc <mask> push pshufhw add <mask> mov...,movabs lwpval fsubp vpternlogq vandnps vcmpneq...,mov add inc add push inc add lea mov add inc a...
5,<mask> cmp cmp sbb jne inc add add <mask> dec ...,xtest vcmpss fsubp fsubp vorps fsetpm movddup ...,add cmp cmp sbb jne inc add add je dec xor rol...
6,pavgusb test <mask> setne xchg <mask> jmp jae ...,xtest cmpltsd fsubp vpternlogq vandnps wrmsr v...,rol test ror setne xchg shr jmp jae dec lea jb...
7,in sub <mask> add rol inc add <mask> <mask> an...,movdqu fsubp fsubp vunpcklpd vcvttps2qq wrmsr ...,in sub mov add rol inc add mov inc and adc or ...
8,and add and in xor sar pextrd jno cld push mov...,maskmovdqu vcmpltps fsubp vpternlogq vandnps v...,and add and in xor sar push jno cld push mov i...
9,cmp je test sbb add add <mask> add and call ad...,movabs fsubp vcmpunordsd fsubp vandnps korw mu...,cmp je test sbb add add mov add and call add a...
