# Setup 

In [None]:
import torch
import datasets
import pandas as pd 
from tqdm import TqdmWarning
from tokenizers import ByteLevelBPETokenizer 
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers.integrations import WandbCallback
from transformers import RobertaConfig, RobertaTokenizerFast, RobertaForMaskedLM
from sklearn.model_selection import train_test_split

import os 
import json 
import pickle
import warnings
import subprocess 

# Hide all warnings (the blanket filter already covers tqdm; the explicit
# TqdmWarning filter just documents that intent)
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore', category=TqdmWarning)

# Set up weights & biases
os.environ["WANDB_PROJECT"] = "malbert-hf"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

# Config options
MAX_LENGTH = 10                 # Max number of tokens in an instruction (one line of a file)
VOCAB_SIZE = 10000              # Target vocabulary size of the BPE tokenizer
SUBSET_SIZE = 0.1               # Fraction of the files to use, in (0, 1]
HIDDEN_SIZE = 768               # Size of hidden layers
NUM_HIDDEN = 12                 # Number of hidden layers
NUM_ATTENTION = 12              # Number of attention heads
INTERMEDIATE_SIZE = 3072        # Size of intermediate (feed-forward) layers
HIDDEN_ACT = "gelu"             # Activation function used in the hidden layers
HIDDEN_DROPOUT_PROB = 0.1       # Dropout probability in hidden layers
ATTENTION_DROPOUT_PROB = 0.1    # Dropout probability in attention mechanisms
INIT_RANGE = 0.02               # Standard deviation of the weight initialiser
LAYER_NORM_EPS = 1e-12          # Epsilon value used by layernorm

EPOCHS = 20                     # Number of training epochs


# Directory containing files of disassembled executables
DATASET_BASE = "/Volumes/New Volume/malware-detection-dataset/opcodes/disasm"

# Evaluation samples
EVAL_DS_PATH = "./data.pickle"

# labels.json maps each file name inside DATASET_BASE to its label.
with open(os.path.join(DATASET_BASE, 'labels.json'), 'r', encoding='utf-8') as dataset_file:
    dataset = json.load(dataset_file)
files = [os.path.join(DATASET_BASE, name) for name in dataset.keys()]
labels = list(dataset.values())

if not 0 < SUBSET_SIZE <= 1:
    raise ValueError(f"SUBSET_SIZE must be in (0, 1], got {SUBSET_SIZE}")
if SUBSET_SIZE < 1:
    # Keep a random SUBSET_SIZE fraction of the files and discard the rest.
    # train_test_split rejects test_size=0, hence the guard for SUBSET_SIZE == 1.
    # The fixed seed makes the subset repeatable across kernel restarts, which
    # matters because the tokenizer and tokenized dataset are cached on disk.
    files, _, labels, _ = train_test_split(files, labels, test_size=1 - SUBSET_SIZE, random_state=42)
def get_no_lines(file):
    """Return the number of newline characters in *file*, like ``wc -l``.

    Counted in pure Python rather than by shelling out to ``wc``. This makes it
    portable, and a missing or unreadable file raises ``OSError`` (e.g.
    ``FileNotFoundError``) instead of an obscure ``ValueError`` from parsing
    empty ``wc`` output. Like ``wc -l``, a final line without a trailing newline
    is not counted.

    Args:
        file: Path of the file to count.

    Returns:
        The number of ``\\n`` bytes in the file.
    """
    newlines = 0
    with open(file, 'rb') as fh:
        # Read in 1 MiB chunks so large disassembly files are not loaded whole.
        for chunk in iter(lambda: fh.read(1 << 20), b''):
            newlines += chunk.count(b'\n')
    return newlines

# print(f"{len(files)} files will be used in training")

# Tokenizer

Train a new byte-level BPE tokenizer on the selected files and save it to `./MalBERT`, unless one has already been trained.

In [None]:
# Train a byte-level BPE tokenizer on the selected files and save it to
# ./MalBERT, unless its files are already there. Check for the saved tokenizer
# files rather than the directory itself: the directory is also the Trainer's
# output_dir, and creating it before the check (as `!mkdir MalBERT` did) made
# the check always true, so the tokenizer was never trained.
if not all(os.path.exists(os.path.join("./MalBERT", name)) for name in ("vocab.json", "merges.txt")):
    # save_model() expects the target directory to exist, so create it first.
    os.makedirs("./MalBERT", exist_ok=True)
    tokenizer = ByteLevelBPETokenizer()
    # The order of special_tokens determines their ids
    # (<s>=0, <pad>=1, </s>=2, <unk>=3, <mask>=4).
    tokenizer.train(files=files, vocab_size=VOCAB_SIZE, min_frequency=2, special_tokens=[
        "<s>",
        "<pad>",
        "</s>",
        "<unk>",
        "<mask>",
    ])

    tokenizer.save_model('MalBERT')

# Dataset

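Set up the dataset: split the selected files into training and test sets, then load them line by line (one example per line).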

## Un-tokenized dataset

In [None]:
# Load the trained BPE files as a RoBERTa tokenizer capped at MAX_LENGTH tokens.
# `model_max_length` is the documented argument; `max_len` is only a legacy alias.
roberta_tokenizer = RobertaTokenizerFast.from_pretrained('./MalBERT', model_max_length=MAX_LENGTH)

def tokenize_fn(line):
    """Tokenize a batch of text lines to exactly MAX_LENGTH tokens each.

    Used with ``Dataset.map(batched=True)``, so ``line['text']`` is a list of
    strings (one per line of a disassembly file), not a single line. Longer
    lines are truncated and shorter ones padded to MAX_LENGTH.

    Returns:
        The tokenizer's output columns, which ``Dataset.map`` adds to the dataset.
    """
    return roberta_tokenizer(line['text'], truncation=True, padding="max_length", max_length=MAX_LENGTH)

# Split the selected *files* (not individual lines) into train/test sets, then
# load every file line by line with the `text` builder (one example per line).
# The fixed seed makes the split repeatable for a given `files` list, keeping it
# consistent with the tokenized cache written to data/tokenized.
train_files, test_files = train_test_split(files, random_state=42)
raw_dataset = datasets.load_dataset('text', data_files={
    "train": train_files,
    "test": test_files,
})

print(f"{len(raw_dataset['train'])} lines in training dataset")
print(f"{len(raw_dataset['test'])} lines in testing dataset")

## Tokenized dataset

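WARNING: This cell may take a long time to run, depending on the size of the dataset. The tokenized dataset is cached in `data/tokenized`; delete that directory to re-tokenize (e.g. after changing the tokenizer, `MAX_LENGTH`, or the selected files).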

In [None]:
# Reuse the on-disk cache when it exists. Otherwise tokenize every split
# (8 worker processes, batches of 1024 lines), drop the raw text column and
# save the result for next time. The cache is never invalidated automatically:
# delete data/tokenized after changing the tokenizer, MAX_LENGTH or the files.
if os.path.exists("data/tokenized"):
    dataset = datasets.load_from_disk("data/tokenized")
else:
    dataset = raw_dataset.map(
        tokenize_fn,
        batched=True,
        batch_size=1024,
        num_proc=8,
        remove_columns=['text'],
    )
    dataset.save_to_disk("data/tokenized")

# Train

## Custom weights & biases callback 

In [None]:
# Weights & Biases callback that logs masked-token predictions for a fixed set
# of evaluation samples once training finishes.
class LogPredictionsCallback(WandbCallback):
    """Log an Input / Actual / Predicted table for stored samples to W&B.

    Args:
        data_path: Path to a pickle holding a mapping with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` entries. Each must convert
            with ``torch.as_tensor`` into a 2-D (sample, position) tensor.
            Presumably ``input_ids`` already contains ``<mask>`` tokens and
            ``labels`` holds the original ids at those positions -- confirm
            against the script that wrote the file.
        tokenizer: Tokenizer used to decode ids and to look up special-token ids.
    """

    def __init__(self, data_path, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        # NOTE: pickle.load can execute arbitrary code; only point data_path at
        # a trusted, locally generated file.
        with open(data_path, 'rb') as data_file:
            self.data = pickle.load(data_file)

    def on_train_end(self, args, state, control, **kwargs):
        """Predict the masked tokens of the stored samples and log them as a table.

        Does not call ``super().on_train_end``. With ``report_to="wandb"`` the
        Trainer also registers its own WandbCallback, which runs the base
        end-of-training behaviour.
        """
        tokenizer = self.tokenizer
        model = kwargs['model']
        device = model.device

        # Build local tensors instead of overwriting self.data, so the stored
        # samples stay untouched and the callback can safely run more than once.
        input_ids = torch.as_tensor(self.data['input_ids'])
        attention_mask = torch.as_tensor(self.data['attention_mask'])
        labels = torch.as_tensor(self.data['labels'])

        model.eval()
        with torch.no_grad():
            output = model(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                labels=labels.to(device),
            )

        mask_id = tokenizer.mask_token_id
        mask_pos = torch.where(input_ids == mask_id)

        # Fill each <mask> with the true token (Actual) or the model's top
        # prediction (Predicted). Boolean-mask assignment visits positions in
        # the same row-major order as torch.where, so the values line up.
        actual = input_ids.clone()
        predicted = input_ids.clone()
        actual[actual == mask_id] = labels[mask_pos[0], mask_pos[1]]
        predicted[predicted == mask_id] = output.logits[mask_pos[0], mask_pos[1], :].argmax(dim=-1).cpu()

        # Hide <s>, <pad>, </s> and <unk> from the displayed input but keep
        # <mask>, so the table shows which positions were masked.
        hidden_ids = torch.tensor([
            tokenizer.bos_token_id,
            tokenizer.pad_token_id,
            tokenizer.eos_token_id,
            tokenizer.unk_token_id,
        ])
        input_text = [tokenizer.decode(row[~torch.isin(row, hidden_ids)]) for row in input_ids]
        actual_text = tokenizer.batch_decode(actual, skip_special_tokens=True)
        predicted_text = tokenizer.batch_decode(predicted, skip_special_tokens=True)

        df = pd.DataFrame({"Input": input_text, "Actual": actual_text, "Predicted": predicted_text})
        table = self._wandb.Table(dataframe=df)
        self._wandb.log({"sample": table})

## Model Creation 

In [None]:
# RoBERTa masked-LM built from the config options in the setup cell.
config = RobertaConfig(
    vocab_size=VOCAB_SIZE,
    # RoBERTa numbers positions from padding_idx + 1 (2 with the default
    # pad_token_id=1). A line with k non-pad tokens therefore uses position ids
    # 2..k+1, and MAX_LENGTH tokens need MAX_LENGTH + 2 embeddings. With only
    # MAX_LENGTH, lines longer than MAX_LENGTH - 2 tokens index past the table.
    max_position_embeddings=MAX_LENGTH + 2,
    num_attention_heads=NUM_ATTENTION,
    num_hidden_layers=NUM_HIDDEN,
    type_vocab_size=1,
    hidden_size=HIDDEN_SIZE,
    intermediate_size=INTERMEDIATE_SIZE,
    hidden_act=HIDDEN_ACT,
    hidden_dropout_prob=HIDDEN_DROPOUT_PROB,
    attention_probs_dropout_prob=ATTENTION_DROPOUT_PROB,
    initializer_range=INIT_RANGE,
    layer_norm_eps=LAYER_NORM_EPS,
)

# Randomly initialised model (no pretrained weights are loaded).
model = RobertaForMaskedLM(config=config)

## Training Config

In [None]:
# Masked-LM collator: in each batch, 15% of tokens are selected as MLM targets.
data_collator = DataCollatorForLanguageModeling(tokenizer=roberta_tokenizer, mlm=True, mlm_probability=0.15)
# Logs a table of masked-token predictions for the EVAL_DS_PATH samples after training.
callback = LogPredictionsCallback(EVAL_DS_PATH, roberta_tokenizer)

train_args = TrainingArguments(
    output_dir="./MalBERT",
    overwrite_output_dir=True,
    # Driven by the EPOCHS config option rather than a separate hard-coded
    # value; lower EPOCHS in the setup cell for quick experiments.
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=64,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
    report_to="wandb",
)

# NOTE(review): eval_dataset is passed, but train_args sets no evaluation
# strategy, so presumably no evaluation runs during training -- confirm intent.
trainer = Trainer(
    model=model,
    args=train_args,
    processing_class=roberta_tokenizer,
    data_collator=data_collator,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test']
)

trainer.add_callback(callback)

## Start training

In [None]:
# Run training (callbacks fire along the way), then write the final model to ./MalBERT.
train_result = trainer.train()
trainer.save_model("./MalBERT")

# Evaluation

Predictions made on the evaluation samples (`EVAL_DS_PATH`), read from a locally saved copy of the W&B `sample` table.

In [None]:
# Load a locally saved copy of the W&B "sample" table and display it as a
# DataFrame (the last expression in the cell is rendered by the notebook).
# `with` closes the file handle, which `json.load(open(...))` leaked.
with open("artifacts/run-159ev9rd-sample:v0/sample.table.json", "r", encoding="utf-8") as table_file:
    data = json.load(table_file)
pd.DataFrame(columns=data['columns'], data=data['data'])